Commit 5988d2cc authored by yuguo960516

bert-large

parent 478602ba
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
from collections import defaultdict
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple
import numpy as np
import oneflow as flow
from oneflow import nn
from termcolor import colored
import libai.utils.distributed as dist
from libai.utils.file_io import HTTPURLHandler, PathManagerBase
class _IncompatibleKeys(
NamedTuple(
# pyre-fixme[10]: Name `IncompatibleKeys` is used but not defined.
"IncompatibleKeys",
[
("missing_keys", List[str]),
("unexpected_keys", List[str]),
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
("incorrect_shapes", List[Tuple]),
],
)
):
pass
class Checkpointer(object):
"""
A checkpointer that can save/load model as well as extra checkpointable
objects.
"""
# NOTE: only support data_parallel for saving model
# TODO: save model: support model_parallel and pipeline parallel
def __init__(
self,
model: nn.Module,
save_dir: str = "",
*,
save_to_disk: bool = True,
**checkpointables: object,
):
"""
Args:
model (nn.Module): model.
save_dir (str): a directory to save and find checkpoints.
save_to_disk (bool): if True, save checkpoint to disk, otherwise
disable saving for this checkpointer.
checkpointables (object): any checkpointable objects, i.e., objects
that have the `state_dict()` and `load_state_dict()` method. For
example, it can be used like
`Checkpointer(model, "dir", optimizer=optimizer)`.
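
        Example:
            A minimal usage sketch; ``build_model``, ``cfg`` and ``optimizer`` below
            are placeholders, not part of this module:

            .. code-block:: python

                model = build_model(cfg)
                checkpointer = Checkpointer(model, "output/checkpoints", optimizer=optimizer)
                checkpointer.save("model_0000999")
                extra = checkpointer.load("output/checkpoints/model_0000999")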
"""
self.model = model
self.checkpointables = copy.copy(checkpointables)
self.logger = logging.getLogger(__name__)
self.save_dir = save_dir
self.save_to_disk = save_to_disk
# Default PathManager, support HTTP URLs
        # A user may want to use a different project-specific PathManagerBase.
self.path_manager: PathManagerBase = PathManagerBase()
self.path_manager.register_handler(HTTPURLHandler())
def save(self, name: str, **kwargs: Dict[str, str]):
"""
Dump model and checkpointables to a file.
Args:
name (str): name of the file.
kwargs (dict): extra arbitrary data to save.
"""
data = {}
data["model"] = self.model.state_dict()
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)
basename = name
save_dir = os.path.join(self.save_dir, basename)
assert os.path.basename(save_dir) == basename, basename
if not self.path_manager.exists(save_dir):
self.path_manager.mkdirs(save_dir)
self.logger.info("Saving checkpoint to {}".format(save_dir))
for save_name in data:
if save_name == "iteration":
continue
save_file = os.path.join(save_dir, save_name)
            # If the target already exists, remove it before saving.
            if self.path_manager.exists(save_file):
                self.path_manager.rm(save_file)
flow.save(data[save_name], save_file, global_dst_rank=0)
if basename != "model_best":
self.tag_last_checkpoint(basename)
def load(self, path: str, checkpointables: Optional[List[str]] = None) -> object:
"""
        Load from the given checkpoint. When the path points to a network file, this
        function has to be called on all ranks.
Args:
path (str): path or url to the checkpoint. If empty, will not load
anything.
checkpointables (list): List of checkpointable names to load. If not
specified (None), will load all the possible checkpointables.
Returns:
dict:
extra data loaded from the checkpoint that has not been
processed. For example, those saved with
:meth:`.save(**extra_data)`.
"""
if not path:
# no checkpoint provided
self.logger.info("No checkpoint found. Training model from scratch")
return {}
self.logger.info("Loading checkpoint from {}".format(path))
checkpoint = self._load_file(path)
incompatible = self._load_model(checkpoint)
        if incompatible is not None:  # handle some existing subclasses that return None
self._log_incompatible_keys(incompatible)
for key in self.checkpointables if checkpointables is None else checkpointables:
if key in checkpoint: # pyre-ignore
self.logger.info("Loading {} from {}".format(key, path))
obj = self.checkpointables[key]
obj.load_state_dict(checkpoint.pop(key)) # pyre-ignore
# return any further checkpoint data
return checkpoint
def has_checkpoint(self):
"""
Returns:
bool: whether a checkpoint exists in the target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
return self.path_manager.exists(save_file)
def get_checkpoint_file(self):
"""
Returns:
str: The latest checkpoint file in target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
try:
# load checkpoint file in rank0
if flow.env.get_rank() == 0:
with open(save_file, "r") as f:
last_saved = f.read().strip()
else:
last_saved = None
# broadcast checkpoint file to other ranks
last_saved = dist.broadcast_py_object(last_saved, src=0)
except IOError:
# if file doesn't exist, maybe because it has just been
# deleted by a separate process
return ""
return os.path.join(self.save_dir, last_saved)
def resume_or_load(self, path: str, *, resume: bool = True):
"""
If `resume` is True, this method attempts to resume from the last
checkpoint (if exists). Otherwise, load checkpoint from the given path.
This is useful when restarting an interrupted training job.
Args:
path (str): path to the checkpoint.
resume (bool): if True, resume from the last checkpoint if it exists.
Returns:
same as :meth:`load`.
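
        Example:
            A sketch of the typical training-entry usage (paths are illustrative):

            .. code-block:: python

                checkpointer = Checkpointer(model, "output/checkpoints", optimizer=optimizer)
                # Resume from the tagged last checkpoint if one exists,
                # otherwise only load the given (e.g. pretrained) weights.
                extra = checkpointer.resume_or_load("/path/to/pretrained", resume=True)
                last_iter = extra.get("iter", 0)  # "iter" is parsed from the checkpoint name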
"""
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
else:
return self.load(path, checkpointables=[])
def tag_last_checkpoint(self, last_filename_basename: str):
"""
Tag the last checkpoint.
Args:
last_filename_basename (str): the basename of the last filename.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
with self.path_manager.open(save_file, "w") as f:
f.write(last_filename_basename) # pyre-ignore
def _load_file(self, f: str):
"""
Load a checkpoint file. Can be overwritten by subclasses to support
different formats.
Args:
f (str): a locally mounted file path.
Returns:
            dict: with keys "model" and optionally others that are saved by
                the checkpointer. dict["model"] must be a dict which maps strings
                to flow.Tensor or numpy arrays.
"""
data = {}
keys = self.path_manager.ls(f)
# broadcast checkpointer keys to other ranks
keys = dist.broadcast_py_object(keys, src=0)
for key in keys:
data[key] = flow.load(os.path.join(f, key), global_src_rank=0)
try:
data["iter"] = int(f.split("_")[-1])
except: # noqa
self.logger.info(f"iter info in {f} not found, set iter to 0")
data["iter"] = 0
return data
def _load_model(self, checkpoint: Any):
"""
Load weights from a checkpoint.
Args:
checkpoint (Any): checkpoint contains the weights.
"""
checkpoint_state_dict = checkpoint.pop("model")
self._convert_ndarray_to_tensor(checkpoint_state_dict)
# if the state_dict comes from a model that was wrapped in a
# DataParallel or DistributedDataParallel during serialization,
# remove the "module" prefix before performing the matching.
_strip_prefix_if_present(checkpoint_state_dict, "module.")
model_state_dict = self.model.state_dict()
incorrect_shapes = []
for k in list(checkpoint_state_dict.keys()):
if k in model_state_dict:
shape_model = tuple(model_state_dict[k].shape)
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
if shape_model != shape_checkpoint:
incorrect_shapes.append((k, shape_checkpoint, shape_model))
checkpoint_state_dict.pop(k)
incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
return _IncompatibleKeys(
missing_keys=incompatible.missing_keys,
unexpected_keys=incompatible.unexpected_keys,
incorrect_shapes=incorrect_shapes,
)
def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
"""
Log information about the incompatible keys returned by ``_load_model``.
"""
for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
self.logger.warning(
"Skip loading parameter '{}' to the model due to incompatible "
"shapes: {} in the checkpoint but {} in the "
"model! You might want to double check if this is expected.".format(
k, shape_checkpoint, shape_model
)
)
if incompatible.missing_keys:
missing_keys = _filter_reused_missing_keys(self.model, incompatible.missing_keys)
if missing_keys:
self.logger.info(get_missing_parameters_message(missing_keys))
if incompatible.unexpected_keys:
self.logger.info(get_unexpected_parameters_message(incompatible.unexpected_keys))
def _convert_ndarray_to_tensor(self, state_dict: dict):
"""
In-place convert all numpy arrays in the state_dict to flow tensor.
Args:
state_dict (dict): a state-dict to be loaded to the model.
"""
# model could be an OrderedDict with _metadata attribute
# (as returned by oneflow's state_dict()). We should preserve these
# properties.
for k in list(state_dict.keys()):
v = state_dict[k]
if not isinstance(v, np.ndarray) and not isinstance(v, flow.Tensor):
raise ValueError("Unsupported type found in checkpoint! {}: {}".format(k, type(v)))
# If it's local tensor, convert it to global tensor.
if not v.is_global:
if k in self.model.state_dict():
model_v = self.model.state_dict()[k]
state_dict[k] = v.to_global(sbp=model_v.sbp, placement=model_v.placement)
class PeriodicCheckpointer:
"""
Save checkpoints periodically. When `.step(iteration)` is called, it will
execute `checkpointer.save` on the given checkpointer, if iteration is a
multiple of period or if `max_iter` is reached.
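
    Example:
        A hedged sketch of how this is typically driven from a training loop
        (``train_one_iter`` and ``start_iter`` are placeholders):

        .. code-block:: python

            periodic_checkpointer = PeriodicCheckpointer(
                checkpointer, period=5000, max_iter=100000, max_to_keep=3
            )
            for iteration in range(start_iter, 100000):
                train_one_iter()
                periodic_checkpointer.step(iteration)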
"""
def __init__(
self,
checkpointer: Checkpointer,
period: int,
max_iter: Optional[int] = None,
max_to_keep: Optional[int] = None,
file_prefix: str = "model",
):
"""
Args:
            checkpointer (Checkpointer): the checkpointer object used to save checkpoints.
            period (int): the period (in iterations) at which to save checkpoints.
            max_iter (int): maximum number of iterations. When it is reached, a checkpoint
                named "{file_prefix}_final" will be saved.
            max_to_keep (int): maximum number of the most recent periodic checkpoints to keep;
                older ones are deleted. If None, all checkpoints are kept.
            file_prefix (str): prefix for the saved checkpoint names.
"""
self.checkpointer = checkpointer
self.period = int(period)
self.max_iter = max_iter
if max_to_keep is not None:
assert max_to_keep > 0
self.max_to_keep = max_to_keep
self.recent_checkpoints: List[str] = []
self.file_prefix = file_prefix
self.path_manager: PathManagerBase = checkpointer.path_manager
def step(self, iteration: int, **kwargs: Any):
"""
Perform the appropriate action at the given iteration.
Args:
            iteration (int): the current iteration, in the range [0, max_iter - 1].
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
iteration = int(iteration)
additional_state = {"iteration": iteration}
additional_state.update(kwargs)
if (iteration + 1) % self.period == 0:
self.checkpointer.save(
"{}_{:07d}".format(self.file_prefix, iteration), **additional_state
)
if self.max_to_keep is not None:
self.recent_checkpoints.append(self.checkpointer.get_checkpoint_file())
if len(self.recent_checkpoints) > self.max_to_keep:
file_to_delete = self.recent_checkpoints.pop(0)
if self.path_manager.exists(file_to_delete) and not file_to_delete.endswith(
"{}_{:07d}".format(self.file_prefix, iteration)
):
self.path_manager.rm(file_to_delete)
if self.max_iter is not None:
if iteration >= self.max_iter - 1:
self.checkpointer.save(f"{self.file_prefix}_final", **additional_state)
def save(self, name: str, **kwargs: Any):
"""
Same argument as :meth:`Checkpointer.save`.
Use this method to manually save checkpoints outside the schedule.
Args:
name (str): file name.
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
self.checkpointer.save(name, **kwargs)
def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
"""
Filter "missing keys" to not include keys that have been loaded with another name.
"""
keyset = set(keys)
    param_to_names = defaultdict(set)  # param -> names that point to it
for module_prefix, module in _named_modules_with_dup(model):
for name, param in list(module.named_parameters(recurse=False)) + list(
module.named_buffers(recurse=False) # pyre-ignore
):
full_name = (module_prefix + "." if module_prefix else "") + name
param_to_names[param].add(full_name)
for names in param_to_names.values():
# if one name appears missing but its alias exists, then this
# name is not considered missing
if any(n in keyset for n in names) and not all(n in keyset for n in names):
[keyset.remove(n) for n in names if n in keyset]
return list(keyset)
def get_missing_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the model but not found in a checkpoint.
Args:
keys (list[str]): List of keys that were not found in the checkpoint.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "Some model parameters or buffers are not found in the checkpoint:\n"
msg += "\n".join(" " + colored(k + _group_to_str(v), "blue") for k, v in groups.items())
return msg
def get_unexpected_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the checkpoint but not found in the model.
Args:
keys (list[str]): List of keys that were not found in the model.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
msg += "\n".join(" " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items())
return msg
def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
"""
Strip the prefix in metadata, if any.
Args:
state_dict (OrderedDict): a state-dict to be loaded to the model.
prefix (str): prefix.
"""
keys = sorted(state_dict.keys())
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
return
for key in keys:
newkey = key[len(prefix) :]
state_dict[newkey] = state_dict.pop(key)
    # also strip the prefix in metadata, if any.
try:
metadata = state_dict._metadata # pyre-ignore
except AttributeError:
pass
else:
for key in list(metadata.keys()):
# for the metadata dict, the key can be:
# '': for the DDP module, which we want to remove.
# 'module': for the actual model.
# 'module.xx.xx': for the rest.
if len(key) == 0:
continue
newkey = key[len(prefix) :]
metadata[newkey] = metadata.pop(key)
def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
"""
Group keys based on common prefixes. A prefix is the string up to the final
"." in each key.
Args:
keys (list[str]): list of parameter names, i.e. keys in the model
checkpoint dict.
Returns:
dict[list]: keys with common prefixes are grouped into lists.
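
    Example:
        With illustrative key names:

        .. code-block:: python

            _group_checkpoint_keys(["bert.embeddings.weight", "bert.embeddings.bias", "head"])
            # -> {"bert.embeddings": ["weight", "bias"], "head": []}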
"""
groups = defaultdict(list)
for key in keys:
pos = key.rfind(".")
if pos >= 0:
head, tail = key[:pos], [key[pos + 1 :]]
else:
head, tail = key, []
groups[head].extend(tail)
return groups
def _group_to_str(group: List[str]) -> str:
"""
Format a group of parameter name suffixes into a loggable string.
Args:
group (list[str]): list of parameter name suffixes.
Returns:
        str: formatted string.
"""
if len(group) == 0:
return ""
if len(group) == 1:
return "." + group[0]
return ".{" + ", ".join(group) + "}"
def _named_modules_with_dup(model: nn.Module, prefix: str = "") -> Iterable[Tuple[str, nn.Module]]:
"""
The same as `model.named_modules()`, except that it includes
duplicated modules that have more than one name.
"""
yield prefix, model
for name, module in model._modules.items(): # pyre-ignore
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
yield from _named_modules_with_dup(module, submodule_prefix)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import dill
import numpy as np
import oneflow as flow
from omegaconf import OmegaConf
from libai.config import try_get_key
logger = logging.getLogger(__name__)
_DIST_UTIL = None
def _merge_devices(devices):
num_gpus_per_node = get_world_size() // get_num_nodes()
node_devices = [node_id * num_gpus_per_node + device_id for node_id, device_id in devices]
return node_devices
class _DistributeUtil(object):
def __init__(self, cfg):
self._init_distributed_env(cfg)
self._init_parallel_size(cfg)
self._init_placement_group(cfg)
self._init_parallel_hierarchy()
def _init_distributed_env(self, cfg):
"""Initialize the distributed environment."""
num_nodes = get_num_nodes()
num_gpus_per_node = get_world_size() // num_nodes
if try_get_key(cfg, "num_gpus_per_node", default=num_gpus_per_node) != num_gpus_per_node:
# This means key(num_gpus_per_node) saved in config is not equal
# to environment variable.
# Give user a warning about inconsistent reproduce environment.
logger.warning(
"'train.dist.num_gpus_per_node' are not equal to environment variable. "
f"{cfg.num_gpus_per_node} != {num_gpus_per_node}"
)
if try_get_key(cfg, "num_nodes", default=num_nodes) != num_nodes:
logger.warning(
"'train.dist.num_nodes' are not equal to"
f"environment variable. {cfg.num_nodes} != {num_nodes}"
)
# Set the actual value to config
cfg.num_nodes = num_nodes
cfg.num_gpus_per_node = num_gpus_per_node
self._num_nodes = num_nodes
self._num_gpus_per_node = num_gpus_per_node
self._world_size = num_gpus_per_node * num_nodes
# Add set device type
self._device_type = try_get_key(cfg, "device_type", default="cuda")
def _init_parallel_size(self, cfg):
# tensor parallel size
self._tensor_parallel_size = min(cfg.tensor_parallel_size, self.world_size)
assert self.world_size % self._tensor_parallel_size == 0, (
f"world size ({self.world_size}) is not divisible by"
f" tensor parallel size ({self._tensor_parallel_size})"
)
# Set the actual tensor parallel size to cfg
cfg.tensor_parallel_size = self._tensor_parallel_size
# pipeline parallel size
self._pipeline_parallel_size = min(
cfg.pipeline_parallel_size, self.world_size // cfg.tensor_parallel_size
)
# Set the actual pipeline parallel size to cfg
cfg.pipeline_parallel_size = self._pipeline_parallel_size
if cfg.pipeline_parallel_size > 1:
assert (
try_get_key(cfg, "pipeline_num_layers") is not None
), "cfg.train.dist.pipeline_num_layers must be set when run pipeline parallel"
assert cfg.pipeline_num_layers >= self._pipeline_parallel_size, (
f"number of layers ({cfg.pipeline_num_layers}) is less than"
f" pipeline model parallel size ({self._pipeline_parallel_size})"
)
if try_get_key(cfg, "custom_pipeline_stage_id") is not None:
assert OmegaConf.is_list(
cfg.custom_pipeline_stage_id
), "type of cfg.train.dist.custom_pipeline_stage_id must be list"
cfg.custom_pipeline_stage_id = list(cfg.custom_pipeline_stage_id)
assert max(cfg.custom_pipeline_stage_id) < self._world_size, (
f"the element {max(cfg.custom_pipeline_stage_id)} in"
" cfg.train.dist.custom_pipeline_stage_id is out of range"
f" for total rank {self._world_size}"
)
assert len(cfg.custom_pipeline_stage_id) == cfg.pipeline_num_layers, (
"the length of cfg.train.dist.custom_pipeline_stage_id"
f" {len(cfg.custom_pipeline_stage_id)} must be equal to"
" cfg.train.dist.pipeline_num_layers"
f" {cfg.train.dist.pipeline_num_layers}"
)
else:
            # no pipeline parallelism; set a large default so any reasonable layer index maps to the single stage
if try_get_key(cfg, "pipeline_num_layers") is None:
cfg.pipeline_num_layers = 10000
self._model_parallel_size = self._pipeline_parallel_size * self._tensor_parallel_size
assert self.world_size % self._model_parallel_size == 0, (
f"world size ({self.world_size}) is not divisible by"
f" tensor model parallel size ({self._tensor_parallel_size}) times"
f" pipeline model parallel size ({self._pipeline_parallel_size})"
)
# data parallel size
self._data_parallel_size = self.world_size // self._model_parallel_size
# Set the actual data parallel size to cfg
cfg.data_parallel_size = self._data_parallel_size
def _init_placement_group(self, cfg):
node_ids = [i // self.num_gpus_per_node for i in range(self.world_size)]
device_ids = list(range(self.num_gpus_per_node)) * self.num_nodes
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]
devices = [(n, d) for n, d in zip(node_ids, device_ids)]
num_devices_per_stage = self.world_size // self._pipeline_parallel_size
stages_devices = [
_merge_devices(devices[i : (i + num_devices_per_stage)])
for i in range(0, self.world_size, num_devices_per_stage)
]
# change pipeline_num_layers to make the middle stages contain more layers
if (
self._pipeline_parallel_size >= 4
and cfg.pipeline_num_layers >= 8
and cfg.pipeline_num_layers % self._pipeline_parallel_size == 0
):
temp_num_layers_per_stage = cfg.pipeline_num_layers // self._pipeline_parallel_size
actual_pipeline_num_layers = cfg.pipeline_num_layers + min(
self._pipeline_parallel_size - 1, temp_num_layers_per_stage
)
else:
actual_pipeline_num_layers = cfg.pipeline_num_layers
num_layers_per_stage = actual_pipeline_num_layers // self._pipeline_parallel_size
stage_offset = actual_pipeline_num_layers % self._pipeline_parallel_size
# stage_offset can make the later stages contain more layers when pipeline_num_layers
# cannot be divided by pipeline_parallel_size.
# This can make pipeline parallel more memory efficient.
self._layer_stage_ids = []
for i in range(0, actual_pipeline_num_layers - stage_offset, num_layers_per_stage):
stage_id = i // num_layers_per_stage
if stage_id >= (self._pipeline_parallel_size - stage_offset):
self._layer_stage_ids.append(stage_id)
self._layer_stage_ids.extend([stage_id] * num_layers_per_stage)
self._layer_stage_ids = self._layer_stage_ids[: cfg.pipeline_num_layers]
        # when pipeline_parallel_size > 1, we add pipeline_stage_id information into cfg
if cfg.pipeline_parallel_size > 1:
cfg.auto_pipeline_stage_id = self._layer_stage_ids
# set pipeline_stage_id by users' setting
if try_get_key(cfg, "custom_pipeline_stage_id") is not None:
self._layer_stage_ids = cfg.custom_pipeline_stage_id
cfg.actual_pipeline_stage_id = self._layer_stage_ids
self._layer_ranks = [stages_devices[stage_id] for stage_id in self._layer_stage_ids]
def _init_parallel_hierarchy(self):
if self.is_data_model_parallel():
self._parallel_hierarchy = (
self._data_parallel_size,
self._tensor_parallel_size,
)
else:
self._parallel_hierarchy = None
@property
def num_nodes(self):
return self._num_nodes
@property
def num_gpus_per_node(self):
return self._num_gpus_per_node
@property
def world_size(self):
return self._world_size
@property
def parallel_hierarchy(self):
return self._parallel_hierarchy
@property
def tensor_parallel_size(self):
return self._tensor_parallel_size
@property
def pipeline_parallel_size(self):
return self._pipeline_parallel_size
@property
def model_parallel_size(self):
return self._tensor_parallel_size
@property
def data_parallel_size(self):
return self._data_parallel_size
@property
def device_type(self):
return self._device_type
def set_device_type(self, device_type):
assert device_type in ["cpu", "cuda"], f"not supported for {device_type}"
self._device_type = device_type
def get_layer_ranks(self, layer_idx):
layer_ranks = self._layer_ranks[layer_idx]
if self._parallel_hierarchy is None:
return layer_ranks
else:
assert len(self._parallel_hierarchy) == 2
return np.asarray(layer_ranks).reshape(self._parallel_hierarchy).tolist()
def get_layer_stage_id(self, layer_idx):
return self._layer_stage_ids[layer_idx]
def is_tensor_model_parallel(self):
return self._tensor_parallel_size > 1
def is_data_parallel(self):
return self._data_parallel_size > 1
def is_pipeline_model_parallel(self):
return self._pipeline_parallel_size > 1
def is_data_model_parallel(self):
return self.is_tensor_model_parallel() and self.is_data_parallel()
def setup_dist_util(cfg):
"""Initialize the distributed environment with configuration.
Example:
.. code-block:: python
from omegaconf import DictConfig
# set the hybrid parallel distributed environment with 2D mesh GPUs
setup_dist_util(
DictConfig(
dict(
data_parallel_size=2,
tensor_parallel_size=2,
pipeline_parallel_size=1,
)
)
)
"""
global _DIST_UTIL
_DIST_UTIL = _DistributeUtil(cfg)
def get_dist_util():
"""Get distributed utils if it's been setup. Otherwise, initialize it with
single node/single gpu environment."""
global _DIST_UTIL
if _DIST_UTIL is None:
logger.warning(
"Distributed env is not set up, configure it by default (single node, single gpu)."
)
from omegaconf import DictConfig
setup_dist_util(
DictConfig(
dict(
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=1,
)
)
)
return _DIST_UTIL
def get_layer_placement(layer_idx, device_type=None):
"""
Get ``flow.placement`` object with the initialized distributed environment
according to the ``layer_idx``.
Args:
layer_idx (int): layer index indicating the rank groups. This is very useful for pipeline
parallelism training where different layers are on different ranks.
device_type (str, optional): device type. Defaults to "cuda".
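
    Example:
        A minimal sketch mirroring how this module builds global tensors; the shape
        is illustrative:

        .. code-block:: python

            placement = get_layer_placement(0)
            sbp = get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
            x = flow.zeros(4, 8, placement=placement, sbp=sbp)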
"""
dist_util = get_dist_util()
device_type = dist_util.device_type if device_type is None else device_type
if not flow.cuda.is_available() and device_type == "cuda":
device_type = "cpu"
return flow.placement(
device_type,
dist_util.get_layer_ranks(layer_idx),
)
def get_nd_sbp(sbp_list):
"""Get nd sbp signature list, which is consistent with 1D/2D mesh GPUs.
Args:
sbp_list (list): a sbp list with 2D mesh.
Returns:
A modified sbp list according to the initialized distributed environment.
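
    Example:
        A sketch of the common pattern for activations (split over the data-parallel
        dimension, broadcast over the tensor-parallel one); the shape is illustrative:

        .. code-block:: python

            sbp = get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast])
            hidden = flow.zeros(8, 512, placement=get_layer_placement(0), sbp=sbp)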
"""
assert isinstance(sbp_list, list)
assert len(sbp_list) == 2
assert all(isinstance(sbp, flow.sbp.sbp) for sbp in sbp_list)
dist_util = get_dist_util()
if dist_util.is_data_model_parallel():
return sbp_list
elif dist_util.is_data_parallel():
return sbp_list[:1]
elif dist_util.is_tensor_model_parallel():
return sbp_list[1:]
else:
return [flow.sbp.broadcast]
def get_hidden_sbp():
"""Hidden states sbp."""
return get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast])
def get_data_parallel_rank():
dist_util = get_dist_util()
return (flow.env.get_rank() // dist_util.model_parallel_size) % dist_util.data_parallel_size
def get_data_parallel_size():
dist_util = get_dist_util()
return dist_util.data_parallel_size
def get_tensor_parallel_size():
dist_util = get_dist_util()
return dist_util.tensor_parallel_size
def get_pipeline_parallel_size():
dist_util = get_dist_util()
return dist_util.pipeline_parallel_size
def same_sbp(lhs_sbp, rhs_sbp):
"""Determine if two sbp signatures are the same."""
assert len(lhs_sbp) == len(rhs_sbp)
for i in range(len(lhs_sbp)):
if lhs_sbp[i] != rhs_sbp[i]:
return False
return True
def get_rank() -> int:
return flow.env.get_rank()
def get_local_rank() -> int:
return flow.env.get_local_rank()
def is_main_process() -> bool:
return get_rank() == 0
def is_last_process() -> bool:
return get_rank() == get_world_size() - 1
def get_world_size():
return flow.env.get_world_size()
def get_num_nodes():
return flow.env.get_node_size()
def set_device_type(device_type):
dist_util = get_dist_util()
dist_util.set_device_type(device_type)
def broadcast_py_object(obj, src: int = 0):
rank = flow.env.get_rank()
if src == rank:
obj_bytes = dill.dumps(obj)
return dill.loads(flow._oneflow_internal.cpu_broadcast(obj_bytes, src))
else:
return dill.loads(flow._oneflow_internal.cpu_broadcast(None, src))
def convert_to_distributed_default_setting(t):
"""
    Helper function to convert an eager local tensor to a global tensor with the
    default distributed setting (broadcast sbp and stage-0 placement).
"""
if not t.is_global:
return t.to_global(
sbp=get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=get_layer_placement(0),
)
else:
return t
def ttol(tensor, pure_local=False, ranks=None):
"""Global tensor to local tensor."""
if tensor.is_global:
placement = tensor.placement if not ranks else flow.placement("cuda", ranks)
if pure_local:
tensor = tensor.to_global(placement=placement).to_local()
else:
tensor = tensor.to_global(
sbp=get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=placement
).to_local()
return tensor
def tton(tensor, local_only=False, ranks=None):
"""Global tensor to numpy ndarray."""
if tensor.is_global:
tensor = ttol(tensor, local_only, ranks)
return tensor.numpy()
def tensor_to_rank0(tensor, device="cuda", to_local=False):
"""Global tensor to rank0."""
assert device in ["cpu", "cuda"], f"not supported for device:{device}"
if tensor.is_global:
        # If the placement uses a 2D mesh, ranks should be [[0]] instead of [0]
placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
tensor = tensor.to_global(
sbp=get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=placement
)
if to_local:
tensor = ttol(tensor)
return tensor
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training.
"""
world_size = get_world_size()
if world_size == 1:
return
flow.comm.barrier()
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
from typing import Callable, List, Optional
from urllib import request
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/iopath/blob/main/iopath/common/download.py
# --------------------------------------------------------
def download(url: str, dir: str, *, filename: Optional[str] = None, progress: bool = True) -> str:
"""
Download a file from a given URL to a directory. If file exists, will not
overwrite the existing file.
Args:
url (str):
dir (str): the directory to download the file
filename (str or None): the basename to save the file.
Will use the name in the URL if not given.
progress (bool): whether to use tqdm to draw a progress bar.
Returns:
str: the path to the downloaded file or the existing one.
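
    Example:
        A hedged sketch (the URL and directory are illustrative only):

        .. code-block:: python

            path = download(
                "https://example.com/vocab.txt", "./downloads", progress=True
            )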
"""
os.makedirs(dir, exist_ok=True)
if filename is None:
filename = url.split("/")[-1]
assert len(filename), "Cannot obtain filename from url {}".format(url)
fpath = os.path.join(dir, filename)
logger = logging.getLogger(__name__)
if os.path.isfile(fpath):
logger.info("File {} exists! Skipping download.".format(filename))
return fpath
tmp = fpath + ".tmp" # download to a tmp file first, to be more atomic.
try:
logger.info("Downloading from {} ...".format(url))
if progress:
import tqdm
def hook(t: tqdm.tqdm) -> Callable[[int, int, Optional[int]], None]:
last_b: List[int] = [0]
def inner(b: int, bsize: int, tsize: Optional[int] = None) -> None:
if tsize is not None:
t.total = tsize
t.update((b - last_b[0]) * bsize) # type: ignore
last_b[0] = b
return inner
with tqdm.tqdm( # type: ignore
unit="B", unit_scale=True, miniters=1, desc=filename, leave=True
) as t:
tmp, _ = request.urlretrieve(url, filename=tmp, reporthook=hook(t))
else:
tmp, _ = request.urlretrieve(url, filename=tmp)
statinfo = os.stat(tmp)
size = statinfo.st_size
if size == 0:
raise IOError("Downloaded an empty file from {}!".format(url))
# download to tmp first and move to fpath, to make this function more
# atomic.
shutil.move(tmp, fpath)
except IOError:
logger.error("Failed to download {}".format(url))
raise
finally:
try:
os.unlink(tmp)
except IOError:
pass
logger.info("Successfully downloaded " + fpath + ". " + str(size) + " bytes.")
return fpath
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import json
import logging
import os
import time
from collections import defaultdict
from contextlib import contextmanager
from libai.utils.file_io import PathManager
from libai.utils.history_buffer import HistoryBuffer
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/events.py
# --------------------------------------------------------
__all__ = [
"get_event_storage",
"JSONWriter",
"CommonMetricPrinter",
"EventStorage",
]
_CURRENT_STORAGE_STACK = []
def get_event_storage():
"""
Returns:
The :class:`EventStorage` object that's currently being used.
Throw an error if no :class:`EventStorage` is currently enabled.
"""
assert len(
_CURRENT_STORAGE_STACK
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
return _CURRENT_STORAGE_STACK[-1]
class EventWriter:
"""
Base class for writers that obtain events from :class:`EventStorage` and process them.
"""
def write(self):
raise NotImplementedError
def close(self):
pass
class JSONWriter(EventWriter):
"""
Write scalars to a json file.
It saves scalars as one json per line (instead of a big json) for easy parsing.
Example of parsing such a json file:
::
$ cat metrics.json | jq -s '.[0:2]'
[
{
"data_time": 0.008433341979980469,
"iteration": 19,
"total_loss": 1.9228371381759644,
"lr": 0.007173333333333333,
"time": 0.25401854515075684
},
{
"data_time": 0.007216215133666992,
"iteration": 39,
"total_loss": 1.282649278640747,
"lr": 0.007706666666666667,
"time": 0.2490077018737793
}
]
$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
"""
def __init__(self, json_file, window_size=20):
"""
Args:
json_file (str): path to the json file. New data will be appended if the file exists.
window_size (int): the window size of median smoothing for the scalars whose
`smoothing_hint` are True.
"""
self._file_handle = PathManager.open(json_file, "a")
self._window_size = window_size
self._last_write = -1
def write(self):
storage = get_event_storage()
to_save = defaultdict(dict)
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
# keep scalars that have not been written
if iter <= self._last_write:
continue
to_save[iter][k] = v
if len(to_save):
all_iters = sorted(to_save.keys())
self._last_write = max(all_iters)
for itr, scalars_per_iter in to_save.items():
scalars_per_iter["iteration"] = itr
self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
self._file_handle.flush()
try:
os.fsync(self._file_handle.fileno())
except AttributeError:
pass
def close(self):
self._file_handle.close()
class TensorboardXWriter(EventWriter):
"""
Write all scalars to a tensorboard file
"""
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
"""
Args:
log_dir (str): the directory to save the output events
window_size (int): the scalars will be median-smoothed by this window size
kwargs: other arguments passed to `tensorboardX.SummaryWriter(...)`
"""
self._window_size = window_size
from tensorboardX import SummaryWriter
self._writer = SummaryWriter(log_dir=log_dir, **kwargs)
self._last_write = -1
def write(self):
storage = get_event_storage()
new_last_write = self._last_write
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
if iter > self._last_write:
self._writer.add_scalar(k, v, iter)
new_last_write = max(new_last_write, iter)
self._last_write = new_last_write
# TODO: add write image
if len(storage._histograms) >= 1:
for params in storage._histograms:
self._writer.add_histogram_raw(**params)
storage.clear_histograms()
def close(self):
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
self._writer.close()
class CommonMetricPrinter(EventWriter):
"""
Print **common** metrics to the terminal, including
iteration time, ETA, memory, all losses, and the learning rate.
It also applies smoothing using a window of 20 elements.
It's meant to print common metrics in common ways.
To print something in more customized ways, please implement a similar printer by yourself.
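
    Example:
        A sketch of the usual writer setup (values are illustrative):

        .. code-block:: python

            writers = [
                CommonMetricPrinter(batch_size=256, max_iter=100000),
                JSONWriter("output/metrics.json"),
            ]
            # called periodically inside an EventStorage context, e.g. every 20 iterations
            for writer in writers:
                writer.write()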
"""
def __init__(self, batch_size, max_iter):
"""
        Args:
            batch_size (int): the total training batch size, used to compute throughput.
            max_iter (int): the maximum number of iterations to train.
                Used to compute ETA.
"""
self.logger = logging.getLogger(__name__)
self._batch_size = batch_size
self._max_iter = max_iter
self._last_write = None
def write(self):
storage = get_event_storage()
iteration = storage.iter
consumed_samples = storage.samples
if iteration == self._max_iter:
# This hook only reports training progress (loss, ETA, etc) but not other data,
# therefore do not write anything after training succeeds, even if this method
# is called.
return
try:
data_time = storage.history("data_time").avg(20)
except KeyError:
# they may not exist in the first few iterations (due to warmup)
# or when SimpleTrainer is not used
data_time = None
eta_string = None
try:
iter_time = storage.history("time").global_avg()
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
except KeyError:
iter_time = None
# estimate eta on our own - more noisy
if self._last_write is not None:
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
iteration - self._last_write[0]
)
eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
self._last_write = (iteration, time.perf_counter())
try:
lr = "{:.2e}".format(storage.history("lr").latest())
except KeyError:
lr = "N/A"
max_mem_mb = None
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
" {eta}{iter} {sample} {losses} {time}{data_time} {tpt} lr: {lr} {memory}".format(
eta=f"eta: {eta_string} " if eta_string else "",
iter=f"iteration: {iteration}/{self._max_iter}",
sample=f"consumed_samples: {consumed_samples}",
losses=" ".join(
[
"{}: {:.4g}".format(k, v.median(200))
for k, v in storage.histories().items()
if "loss" in k
]
),
time="time: {:.4f} s/iter ".format(iter_time) if iter_time is not None else "",
data_time="data_time: {:.4f} s/iter".format(data_time)
if data_time is not None
else "",
tpt="total_throughput: {:.2f} samples/s".format(self._batch_size / iter_time)
if iter_time is not None
else "",
lr=lr,
memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
)
)
class EventStorage:
"""
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
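
    Example:
        A minimal sketch of how metrics are recorded (``data_loader`` and
        ``compute_loss`` are placeholders):

        .. code-block:: python

            with EventStorage(start_iter=0) as storage:
                for data in data_loader:
                    loss = compute_loss(data)
                    storage.put_scalar("total_loss", loss)
                    storage.step()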
"""
def __init__(self, start_iter=0):
"""
Args:
start_iter (int): the iteration number to start with
"""
self._history = defaultdict(HistoryBuffer)
self._smoothing_hints = {}
self._latest_scalars = {}
self._iter = start_iter
        self._batch_size = 0
        self._samples = 0  # consumed samples, exposed via the ``samples`` property
self._current_prefix = ""
self._vis_data = []
self._histograms = []
def put_image(self, img_name, img_tensor):
"""
Add an `img_tensor` associated with `img_name` to be shown on
tensorboard.
Args:
img_name (str): The name of the image to put into tensorboard.
img_tensor (flow.Tensor or numpy.array): An `uint8` or `float`
Tensor of shape `[channel, height, width]` where `channel` is
3. The image format should be RGB. The elements in img_tensor
can either have values in [0, 1] (float32) or [0, 255] (uint8).
The `img_tensor` will be visualized in tensorboard.
"""
self._vis_data.append((img_name, img_tensor, self._iter))
def put_scalar(self, name, value, smoothing_hint=True):
"""
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
Args:
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
smoothed when logged. The hint will be accessible through
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
and apply custom smoothing rule.
It defaults to True because most scalars we save need to be smoothed to
provide any useful signal.
"""
name = self._current_prefix + name
history = self._history[name]
value = float(value)
history.update(value, self._iter)
self._latest_scalars[name] = (value, self._iter)
existing_hint = self._smoothing_hints.get(name)
if existing_hint is not None:
assert (
existing_hint == smoothing_hint
), "Scalar {} was put with a different smoothing_hint!".format(name)
else:
self._smoothing_hints[name] = smoothing_hint
def put_scalars(self, *, smoothing_hint=True, **kwargs):
"""
Put multiple scalars from keyword arguments.
Example:
.. code-block:: python
storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
"""
for k, v in kwargs.items():
self.put_scalar(k, v, smoothing_hint=smoothing_hint)
def history(self, name):
"""
Returns:
HistoryBuffer: the scalar history for name
"""
ret = self._history.get(name, None)
if ret is None:
raise KeyError("No history metric available for {}!".format(name))
return ret
def histories(self):
"""
Returns:
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
"""
return self._history
def latest(self):
"""
Returns:
dict[str -> (float, int)]: mapping from the name of each scalar to the most
            recent value and the iteration number at which it was added.
"""
return self._latest_scalars
def latest_with_smoothing_hint(self, window_size=20):
"""
Similar to :meth:`latest`, but the returned values
are either the un-smoothed original latest value,
or a median of the given window_size,
depending on whether the smoothing_hint is True.
This provides a default behavior that other writers can use.
"""
result = {}
for k, (v, itr) in self._latest_scalars.items():
result[k] = (
self._history[k].median(window_size) if self._smoothing_hints[k] else v,
itr,
)
return result
def smoothing_hints(self):
"""
Returns:
dict[name -> bool]: the user-provided hint on whether the scalar
is noisy and needs smoothing.
"""
return self._smoothing_hints
def step(self):
"""
User should either: (1) Call this function to increment storage.iter when needed.
Or (2) Set `storage.iter` to the correct iteration number before each iteration.
The storage will then be able to associate the new data with an iteration number.
"""
self._iter += 1
@property
def iter(self):
"""
Returns the current iteration number. When used together with a trainer,
this is ensured to be the same as trainer.iter.
"""
return self._iter
@iter.setter
def iter(self, val):
self._iter = int(val)
@property
def samples(self):
return self._samples
@samples.setter
def samples(self, val):
self._samples = int(val)
def __enter__(self):
_CURRENT_STORAGE_STACK.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
assert _CURRENT_STORAGE_STACK[-1] == self
_CURRENT_STORAGE_STACK.pop()
@contextmanager
def name_scope(self, name):
"""
Yields:
A context within which all the events added to this storage
will be prefixed by the name scope.
"""
old_prefix = self._current_prefix
self._current_prefix = name.rstrip("/") + "/"
yield
self._current_prefix = old_prefix
def clear_images(self):
"""
Delete all the stored images for visualization. This should be called
after images are written to tensorboard.
"""
self._vis_data = []
def clear_histograms(self):
"""
Delete all the stored histograms for visualization.
This should be called after histograms are written to tensorboard.
"""
self._histograms = []
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import concurrent.futures
import errno
import logging
import os
import shutil
import tempfile
import traceback
from collections import OrderedDict
from typing import IO, Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Set, Union
from urllib.parse import urlparse
import portalocker
from libai.utils.download import download
from libai.utils.non_blocking_io import NonBlockingIOManager
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/iopath/blob/main/iopath/common/file_io.py
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/file_io.py
# --------------------------------------------------------
__all__ = ["LazyPath", "PathManager", "get_cache_dir", "file_lock"]
def get_cache_dir(cache_dir: Optional[str] = None) -> str:
"""
Returns a default directory to cache static files
(usually downloaded from Internet), if None is provided.
Args:
cache_dir (None or str): if not None, will be returned as is.
If None, returns the default cache directory as:
1) $LIBAI_CACHE, if set
2) otherwise ~/.oneflow/iopath_cache
"""
if cache_dir is None:
cache_dir = os.path.expanduser(os.getenv("LIBAI_CACHE", "~/.oneflow/iopath_cache"))
try:
g_pathmgr.mkdirs(cache_dir)
assert os.access(cache_dir, os.W_OK)
except (OSError, AssertionError):
tmp_dir = os.path.join(tempfile.gettempdir(), "iopath_cache")
logger = logging.getLogger(__name__)
logger.warning(f"{cache_dir} is not accessible! Using {tmp_dir} instead!")
cache_dir = tmp_dir
return cache_dir
def file_lock(path: str): # type: ignore
"""
A file lock. Once entered, it is guaranteed that no one else holds the
same lock. Others trying to enter the lock will block for 30 minutes and
raise an exception.
This is useful to make sure workers don't cache files to the same location.
Args:
path (str): a path to be locked. This function will create a lock named
`path + ".lock"`
Examples:
filename = "/path/to/file"
with file_lock(filename):
if not os.path.isfile(filename):
do_create_file()
"""
dirname = os.path.dirname(path)
try:
os.makedirs(dirname, exist_ok=True)
except OSError:
# makedir is not atomic. Exceptions can happen when multiple workers try
# to create the same dir, despite exist_ok=True.
# When this happens, we assume the dir is created and proceed to creating
# the lock. If failed to create the directory, the next line will raise
# exceptions.
pass
return portalocker.Lock(path + ".lock", timeout=3600) # type: ignore
class LazyPath(os.PathLike):
"""
A path that's lazily evaluated when it's used.
Users should be careful to not use it like a str, because
it behaves differently from a str.
Path manipulation functions in Python such as `os.path.*` all accept
PathLike objects already.
It can be materialized to a str using `os.fspath`.
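
    Example:
        A sketch of deferring an expensive path computation (``expensive_lookup``
        is a placeholder):

        .. code-block:: python

            path = LazyPath(lambda: expensive_lookup())  # nothing evaluated yet
            with open(os.fspath(path)) as f:  # evaluated here, at most once
                data = f.read()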
"""
def __init__(self, func: Callable[[], str]) -> None:
"""
Args:
func: a function that takes no arguments and returns the
actual path as a str. It will be called at most once.
"""
self._func = func
self._value: Optional[str] = None
def _get_value(self) -> str:
if self._value is None:
self._value = self._func()
return self._value # pyre-ignore
def __fspath__(self) -> str:
return self._get_value()
    # behave more like a str after it has been evaluated
def __getattr__(self, name: str): # type: ignore
if self._value is None:
raise AttributeError(f"Uninitialized LazyPath has no attribute: {name}.")
return getattr(self._value, name)
def __getitem__(self, key): # type: ignore
if self._value is None:
raise TypeError("Uninitialized LazyPath is not subscriptable.")
return self._value[key] # type: ignore
def __str__(self) -> str:
if self._value is not None:
return self._value # type: ignore
else:
return super().__str__()
class PathHandler:
"""
PathHandler is a base class that defines common I/O functionality for a URI
protocol. It routes I/O for a generic URI which may look like "protocol://*"
or a canonical filepath "/foo/bar/baz".
"""
_strict_kwargs_check = True
def __init__(
self,
async_executor: Optional[concurrent.futures.Executor] = None,
) -> None:
"""
When registering a `PathHandler`, the user can optionally pass in a
`Executor` to run the asynchronous file operations.
NOTE: For regular non-async operations of `PathManager`, there is
no need to pass `async_executor`.
Args:
async_executor (optional `Executor`): Used for async file operations.
Usage:
```
path_handler = NativePathHandler(async_executor=exe)
path_manager.register_handler(path_handler)
```
"""
self._non_blocking_io_manager = None
self._non_blocking_io_executor = async_executor
def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
"""
Checks if the given arguments are empty. Throws a ValueError if strict
kwargs checking is enabled and args are non-empty. If strict kwargs
checking is disabled, only a warning is logged.
Args:
kwargs (Dict[str, Any])
"""
if self._strict_kwargs_check:
if len(kwargs) > 0:
raise ValueError("Unused arguments: {}".format(kwargs))
else:
logger = logging.getLogger(__name__)
for k, v in kwargs.items():
logger.warning("[PathManager] {}={} argument ignored".format(k, v))
def _get_supported_prefixes(self) -> List[str]:
"""
Returns:
List[str]: the list of URI prefixes this PathHandler can support
"""
raise NotImplementedError()
def _get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk. In this case, the cache stays on filesystem
(under `file_io.get_cache_dir()`) and will be used by a different run.
Therefore this function is meant to be used with read-only resources.
Args:
path (str): A URI supported by this PathHandler
force(bool): Forces a download from backend if set to True.
Returns:
local_path (str): a file path which exists on the local file system
"""
raise NotImplementedError()
def _copy_from_local(
self, local_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> None:
"""
Copies a local file to the specified URI.
If the URI is another local path, this should be functionally identical
to copy.
Args:
local_path (str): a file path which exists on the local file system
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing URI
Returns:
status (bool): True on success
"""
raise NotImplementedError()
def _opent(
self, path: str, mode: str = "r", buffering: int = 32, **kwargs: Any
) -> Iterable[Any]:
raise NotImplementedError()
def _open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
raise NotImplementedError()
def _opena(
self,
path: str,
mode: str = "r",
callback_after_file_close: Optional[Callable[[None], None]] = None,
buffering: int = -1,
**kwargs: Any,
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI with asynchronous permissions.
NOTE: Writes to the same path are serialized so they are written in
the same order as they were called but writes to distinct paths can
happen concurrently.
Usage (default / without callback function):
for n in range(50):
results = run_a_large_task(n)
with path_manager.opena(uri, "w") as f:
f.write(results) # Runs in separate thread
# Main process returns immediately and continues to next iteration
path_manager.async_close()
Usage (advanced / with callback function):
# To write local and then copy to Manifold:
def cb():
path_manager.copy_from_local(
"checkpoint.pt", "manifold://path/to/bucket"
)
f = pm.opena("checkpoint.pt", "wb", callback_after_file_close=cb)
flow.save({...}, f)
f.close()
Args:
...same args as `_open`...
callback_after_file_close (Callable): An optional argument that can
be passed to perform operations that depend on the asynchronous
writes being completed. The file is first written to the local
disk and then the callback is executed.
buffering (int): An optional argument to set the buffer size for
buffered asynchronous writing.
Returns:
file: a file-like object with asynchronous methods.
"""
# Restrict mode until `NonBlockingIO` has async read feature.
valid_modes = {"w", "a", "b"}
if not all(m in valid_modes for m in mode):
raise ValueError("`opena` mode must be write or append")
# TODO: Each `PathHandler` should set its own `self._buffered`
# parameter and pass that in here. Until then, we assume no
# buffering for any storage backend.
if not self._non_blocking_io_manager:
self._non_blocking_io_manager = NonBlockingIOManager(
buffered=False,
executor=self._non_blocking_io_executor,
)
try:
return self._non_blocking_io_manager.get_non_blocking_io(
path=self._get_path_with_cwd(path),
io_obj=self._open(path, mode, **kwargs),
callback_after_file_close=callback_after_file_close,
buffering=buffering,
)
except ValueError:
# When `_strict_kwargs_check = True`, then `open_callable`
# will throw a `ValueError`. This generic `_opena` function
# does not check the kwargs since it may include any `_open`
# args like `encoding`, `ttl`, `has_user_data`, etc.
logger = logging.getLogger(__name__)
logger.exception(
"An exception occurred in `NonBlockingIOManager`. This "
"is most likely due to invalid `opena` args. Make sure "
"they match the `open` args for the `PathHandler`."
)
self._async_close()
def _async_join(self, path: Optional[str] = None, **kwargs: Any) -> bool:
"""
Ensures that desired async write threads are properly joined.
Args:
path (str): Pass in a file path to wait until all asynchronous
activity for that path is complete. If no path is passed in,
then this will wait until all asynchronous jobs are complete.
Returns:
status (bool): True on success
"""
if not self._non_blocking_io_manager:
logger = logging.getLogger(__name__)
logger.warning(
"This is an async feature. No threads to join because " "`opena` was not used."
)
self._check_kwargs(kwargs)
return self._non_blocking_io_manager._join(self._get_path_with_cwd(path) if path else None)
def _async_close(self, **kwargs: Any) -> bool:
"""
Closes the thread pool used for the asynchronous operations.
Returns:
status (bool): True on success
"""
if not self._non_blocking_io_manager:
logger = logging.getLogger(__name__)
logger.warning(
"This is an async feature. No threadpool to close because " "`opena` was not used."
)
self._check_kwargs(kwargs)
return self._non_blocking_io_manager._close_thread_pool()
def _copy(self, src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
raise NotImplementedError()
def _mv(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Moves (renames) a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
Returns:
status (bool): True on success
"""
raise NotImplementedError()
def _exists(self, path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
raise NotImplementedError()
def _isfile(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
raise NotImplementedError()
def _isdir(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
raise NotImplementedError()
def _ls(self, path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
raise NotImplementedError()
def _mkdirs(self, path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
def _rm(self, path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
def _symlink(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Symlink the src_path to the dst_path
Args:
src_path (str): A URI supported by this PathHandler to symlink from
dst_path (str): A URI supported by this PathHandler to symlink to
"""
raise NotImplementedError()
def _set_cwd(self, path: Union[str, None], **kwargs: Any) -> bool:
"""
Set the current working directory. PathHandler classes prepend the cwd
to all URI paths that are handled.
Args:
path (str) or None: A URI supported by this PathHandler. Must be a valid
absolute path or None to set the cwd to None.
Returns:
bool: true if cwd was set without errors
"""
raise NotImplementedError()
def _get_path_with_cwd(self, path: str) -> str:
"""
Default implementation. PathHandler classes that provide a `_set_cwd`
feature should also override this `_get_path_with_cwd` method.
Args:
path (str): A URI supported by this PathHandler.
Returns:
path (str): Full path with the cwd attached.
"""
return path
class NativePathHandler(PathHandler):
"""
Handles paths that can be accessed using Python native system calls. This
handler uses `open()` and `os.*` calls on the given path.
"""
_cwd = None
def _get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
self._check_kwargs(kwargs)
return os.fspath(path)
def _copy_from_local(
self, local_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> None:
self._check_kwargs(kwargs)
local_path = self._get_path_with_cwd(local_path)
dst_path = self._get_path_with_cwd(dst_path)
assert self._copy(src_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs)
def _open(
self,
path: str,
mode: str = "r",
buffering: int = -1,
encoding: Optional[str] = None,
errors: Optional[str] = None,
newline: Optional[str] = None,
closefd: bool = True,
opener: Optional[Callable] = None,
**kwargs: Any,
) -> Union[IO[str], IO[bytes]]:
"""
Open a path.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy works as follows:
* Binary files are buffered in fixed-size chunks; the size of
the buffer is chosen using a heuristic trying to determine the
underlying device’s “block size” and falling back on
io.DEFAULT_BUFFER_SIZE. On many systems, the buffer will
typically be 4096 or 8192 bytes long.
encoding (Optional[str]): the name of the encoding used to decode or
encode the file. This should only be used in text mode.
errors (Optional[str]): an optional string that specifies how encoding
and decoding errors are to be handled. This cannot be used in binary
mode.
newline (Optional[str]): controls how universal newlines mode works
(it only applies to text mode). It can be None, '', '\n', '\r',
and '\r\n'.
closefd (bool): If closefd is False and a file descriptor rather than
a filename was given, the underlying file descriptor will be kept
open when the file is closed. If a filename is given closefd must
be True (the default) otherwise an error will be raised.
opener (Optional[Callable]): A custom opener can be used by passing
a callable as opener. The underlying file descriptor for the file
object is then obtained by calling opener with (file, flags).
opener must return an open file descriptor (passing os.open as opener
results in functionality similar to passing None).
See https://docs.python.org/3/library/functions.html#open for details.
Returns:
file: a file-like object.
"""
self._check_kwargs(kwargs)
return open( # type: ignore
self._get_path_with_cwd(path),
mode,
buffering=buffering,
encoding=encoding,
errors=errors,
newline=newline,
closefd=closefd,
opener=opener,
)
def _copy(self, src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
self._check_kwargs(kwargs)
src_path = self._get_path_with_cwd(src_path)
dst_path = self._get_path_with_cwd(dst_path)
if os.path.exists(dst_path) and not overwrite:
logger = logging.getLogger(__name__)
logger.error("Destination file {} already exists.".format(dst_path))
return False
try:
shutil.copyfile(src_path, dst_path)
return True
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Error in file copy - {}".format(str(e)))
return False
def _mv(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Moves (renames) a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
Returns:
status (bool): True on success
"""
self._check_kwargs(kwargs)
src_path = self._get_path_with_cwd(src_path)
dst_path = self._get_path_with_cwd(dst_path)
if os.path.exists(dst_path):
logger = logging.getLogger(__name__)
logger.error("Destination file {} already exists.".format(dst_path))
return False
try:
shutil.move(src_path, dst_path)
return True
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Error in move operation - {}".format(str(e)))
return False
def _symlink(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Creates a symlink to the src_path at the dst_path
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
Returns:
status (bool): True on success
"""
self._check_kwargs(kwargs)
src_path = self._get_path_with_cwd(src_path)
dst_path = self._get_path_with_cwd(dst_path)
logger = logging.getLogger(__name__)
if not os.path.exists(src_path):
logger.error("Source path {} does not exist".format(src_path))
return False
if os.path.exists(dst_path):
logger.error("Destination path {} already exists.".format(dst_path))
return False
try:
os.symlink(src_path, dst_path)
return True
except Exception as e:
logger.error("Error in symlink - {}".format(str(e)))
return False
def _exists(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.exists(self._get_path_with_cwd(path))
def _isfile(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isfile(self._get_path_with_cwd(path))
def _isdir(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isdir(self._get_path_with_cwd(path))
def _ls(self, path: str, **kwargs: Any) -> List[str]:
self._check_kwargs(kwargs)
return os.listdir(self._get_path_with_cwd(path))
def _mkdirs(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
try:
os.makedirs(path, exist_ok=True)
except OSError as e:
# EEXIST can still happen if multiple processes are creating the dir concurrently
if e.errno != errno.EEXIST:
raise
def _rm(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
os.remove(path)
def _set_cwd(self, path: Union[str, None], **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
# Remove cwd path if None
if path is None:
self._cwd = None
return True
# Make sure path is a valid Unix path
if not os.path.exists(path):
raise ValueError(f"{path} is not a valid Unix path")
# Make sure path is an absolute path
if not os.path.isabs(path):
raise ValueError(f"{path} is not an absolute path")
self._cwd = path
return True
def _get_path_with_cwd(self, path: str) -> str:
return os.path.normpath(path if not self._cwd else os.path.join(self._cwd, path))
class HTTPURLHandler(PathHandler):
"""
Download URLs and cache them to disk.
"""
def __init__(self) -> None:
self.cache_map: Dict[str, str] = {}
def _get_supported_prefixes(self) -> List[str]:
return ["http://", "https://", "ftp://"]
def _get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
"""
This implementation downloads the remote resource and caches it locally.
The resource will only be downloaded if not previously requested.
"""
self._check_kwargs(kwargs)
if force or path not in self.cache_map or not os.path.exists(self.cache_map[path]):
logger = logging.getLogger(__name__)
parsed_url = urlparse(path)
dirname = os.path.join(get_cache_dir(), os.path.dirname(parsed_url.path.lstrip("/")))
filename = path.split("/")[-1]
cached = os.path.join(dirname, filename)
with file_lock(cached):
if not os.path.isfile(cached):
logger.info("Downloading {} ...".format(path))
cached = download(path, dirname, filename=filename)
logger.info("URL {} cached in {}".format(path, cached))
self.cache_map[path] = cached
return self.cache_map[path]
def _open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a remote HTTP path. The resource is first downloaded and cached
locally.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): Not used for this PathHandler.
Returns:
file: a file-like object.
"""
self._check_kwargs(kwargs)
assert mode in ("r", "rb"), "{} does not support open with {} mode".format(
self.__class__.__name__, mode
)
assert (
buffering == -1
), f"{self.__class__.__name__} does not support the `buffering` argument"
local_path = self._get_local_path(path, force=False)
return open(local_path, mode)
class OneDrivePathHandler(HTTPURLHandler):
"""
Map OneDrive (short) URLs to direct download links
"""
ONE_DRIVE_PREFIX = "https://1drv.ms/u/s!"
def create_one_drive_direct_download(self, one_drive_url: str) -> str:
"""
Converts a short OneDrive URI into a download link that can be used with wget
Args:
one_drive_url (str): A OneDrive URI supported by this PathHandler
Returns:
result_url (str): A direct download URI for the file
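Example (illustrative sketch; the short URL below is a placeholder, not a
real share link):
    handler = OneDrivePathHandler()
    direct = handler.create_one_drive_direct_download("https://1drv.ms/u/s!A1b2C3d4")
    # direct == "https://api.onedrive.com/v1.0/shares/u!<base64url-of-the-url>/root/content"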
"""
data_b64 = base64.b64encode(bytes(one_drive_url, "utf-8"))
data_b64_string = data_b64.decode("utf-8").replace("/", "_").replace("+", "-").rstrip("=")
result_url = f"https://api.onedrive.com/v1.0/shares/u!{data_b64_string}/root/content"
return result_url
def _get_supported_prefixes(self) -> List[str]:
return [self.ONE_DRIVE_PREFIX]
def _get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
"""
This implementation downloads the remote resource and caches it locally.
The resource will only be downloaded if not previously requested.
"""
logger = logging.getLogger(__name__)
direct_url = self.create_one_drive_direct_download(path)
logger.info(f"URL {path} mapped to direct download link {direct_url}")
return super()._get_local_path(os.fspath(direct_url), force=force, **kwargs)
class PathManagerBase:
"""
A class for users to open generic paths or translate generic paths to file names.
path_manager.method(path) will do the following:
1. Find a handler by checking the prefixes in `self._path_handlers`.
2. Call handler.method(path) on the handler that's found
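Example (illustrative sketch; the URL is a placeholder):
    path_manager = PathManagerBase()
    path_manager.register_handler(HTTPURLHandler())
    # Dispatched to HTTPURLHandler because of the "https://" prefix;
    # plain local paths fall back to the built-in NativePathHandler.
    with path_manager.open("https://example.com/config.yaml", "r") as f:
        text = f.read()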
"""
def __init__(self) -> None:
self._path_handlers: MutableMapping[str, PathHandler] = OrderedDict()
"""
Dict for path prefix to handler
"""
self._native_path_handler: PathHandler = NativePathHandler()
"""
A NativePathHandler that works on posix paths. This is used as the fallback.
"""
self._cwd: Optional[str] = None
"""
Keeps track of the single cwd (if set).
NOTE: Only one PathHandler can have a cwd set at a time.
"""
self._async_handlers: Set[PathHandler] = set()
"""
Keeps track of the PathHandler subclasses where `opena` was used so
all of the threads can be properly joined when calling
`PathManager.async_join`.
"""
def __get_path_handler(self, path: Union[str, os.PathLike]) -> PathHandler:
"""
Finds a PathHandler that supports the given path. Falls back to the native
PathHandler if no other handler is found.
Args:
path (str or os.PathLike): URI path to resource
Returns:
handler (PathHandler)
"""
path = os.fspath(path) # pyre-ignore
for p in self._path_handlers.keys():
if path.startswith(p):
return self._path_handlers[p]
return self._native_path_handler
def opent(
self, path: str, mode: str = "r", buffering: int = 32, **kwargs: Any
) -> Iterable[Any]:
"""
Open a tabular data source. Only reading is supported.
Unlike open(), which returns bytes/text data, opent() returns a Python
iterable collection object.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'
buffering (int): number of rows fetched and cached
Returns:
An iterable collection object.
"""
return self.__get_path_handler(path)._opent(path, mode, buffering, **kwargs)
def open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
return self.__get_path_handler(path)._open( # type: ignore
path, mode, buffering=buffering, **kwargs
)
# NOTE: This feature is only implemented for `NativePathHandler` and can
# currently only be used in write mode.
def opena(
self,
path: str,
mode: str = "r",
buffering: int = -1,
callback_after_file_close: Optional[Callable[[None], None]] = None,
**kwargs: Any,
) -> Union[IO[str], IO[bytes]]:
"""
Open a file with asynchronous permissions. `f.write()` calls (and
potentially `f.read()` calls in the future) will be dispatched
asynchronously such that the main program can continue running.
NOTE: Writes to the same path are serialized so they are written in
the same order as they were called but writes to distinct paths can
happen concurrently.
Usage (default / without callback function):
for n in range(50):
results = run_a_large_task(n)
# `f` is a file-like object with asynchronous methods
with path_manager.opena(uri, "w") as f:
f.write(results) # Runs in separate thread
# Main process returns immediately and continues to next iteration
path_manager.async_close()
Usage (advanced / with callback function):
# To asynchronously write to Manifold:
def cb():
path_manager.copy_from_local(
"checkpoint.pt", "manifold://path/to/bucket"
)
f = pm.opena("checkpoint.pt", "wb", callback_after_file_close=cb)
oneflow.save({...}, f)
f.close()
Args:
...
callback_after_file_close (Callable): An optional argument that can
be passed to perform operations that depend on the asynchronous
writes being completed. The file is first written to the local
disk and then the callback is executed.
Returns:
file: a file-like object with asynchronous methods.
"""
non_blocking_io = self.__get_path_handler(path)._opena(
path,
mode,
buffering=buffering,
callback_after_file_close=callback_after_file_close,
**kwargs,
)
# Keep track of the path handlers where `opena` is used so that all of the
# threads can be properly joined on `PathManager.async_join`.
self._async_handlers.add(self.__get_path_handler(path))
return non_blocking_io
def async_join(self, *paths: str, **kwargs: Any) -> bool:
"""
Ensures that desired async write threads are properly joined.
Usage:
Wait for asynchronous methods operating on specific file paths to
complete.
async_join("path/to/file1.txt")
async_join("path/to/file2.txt", "path/to/file3.txt")
Wait for all asynchronous methods to complete.
async_join()
Args:
*paths (str): Pass in any number of file paths and `async_join` will wait
until all asynchronous activity for those paths is complete. If no
paths are passed in, then `async_join` will wait until all asynchronous
jobs are complete.
Returns:
status (bool): True on success
"""
success = True
if not paths: # Join all.
for handler in self._async_handlers:
success = handler._async_join(**kwargs) and success
else: # Join specific paths.
for path in paths:
success = self.__get_path_handler(path)._async_join(path, **kwargs) and success
return success
def async_close(self, **kwargs: Any) -> bool:
"""
`async_close()` must be called at the very end of any script that uses the
asynchronous `opena` feature. This calls `async_join()` first and then closes
the thread pool used for the asynchronous operations.
Returns:
status (bool): True on success
"""
success = self.async_join(**kwargs)
for handler in self._async_handlers:
success = handler._async_close(**kwargs) and success
self._async_handlers.clear()
return success
def copy(self, src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
# Copying across handlers is not supported.
assert self.__get_path_handler(src_path) == self.__get_path_handler( # type: ignore
dst_path
)
return self.__get_path_handler(src_path)._copy(src_path, dst_path, overwrite, **kwargs)
def mv(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Moves (renames) a source path supported by NativePathHandler to
a destination path.
Args:
src_path (str): A URI supported by NativePathHandler
dst_path (str): A URI supported by NativePathHandler
Returns:
status (bool): True on success
Exception:
Raises an AssertionError if the src and dst paths are not handled
by the same path handler.
"""
# Moving across handlers is not supported.
assert self.__get_path_handler(src_path) == self.__get_path_handler( # type: ignore
dst_path
), "Src and dest paths must be supported by the same path handler."
return self.__get_path_handler(src_path)._mv(src_path, dst_path, **kwargs)
def get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk.
Args:
path (str): A URI supported by this PathHandler
force(bool): Forces a download from backend if set to True.
Returns:
local_path (str): a file path which exists on the local file system
"""
path = os.fspath(path)
return self.__get_path_handler(path)._get_local_path( # type: ignore
path, force=force, **kwargs
)
def copy_from_local(
self, local_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> None:
"""
Copies a local file to the specified URI.
If the URI is another local path, this should be functionally identical
to copy.
Args:
local_path (str): a file path which exists on the local file system
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing URI
Returns:
status (bool): True on success
"""
assert os.path.exists(local_path)
return self.__get_path_handler(dst_path)._copy_from_local(
local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs
)
def exists(self, path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
return self.__get_path_handler(path)._exists(path, **kwargs) # type: ignore
def isfile(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
return self.__get_path_handler(path)._isfile(path, **kwargs) # type: ignore
def isdir(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
return self.__get_path_handler(path)._isdir(path, **kwargs) # type: ignore
def ls(self, path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
return self.__get_path_handler(path)._ls(path, **kwargs)
def mkdirs(self, path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
return self.__get_path_handler(path)._mkdirs(path, **kwargs) # type: ignore
def rm(self, path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
return self.__get_path_handler(path)._rm(path, **kwargs) # type: ignore
def symlink(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""Symlink the src_path to the dst_path
Args:
src_path (str): A URI supported by this PathHandler to symlink from
dst_path (str): A URI supported by this PathHandler to symlink to
"""
# Copying across handlers is not supported.
assert self.__get_path_handler(src_path) == self.__get_path_handler( # type: ignore
dst_path
)
return self.__get_path_handler(src_path)._symlink(src_path, dst_path, **kwargs)
def set_cwd(self, path: Union[str, None], **kwargs: Any) -> bool:
"""
Set the current working directory. PathHandler classes prepend the cwd
to all URI paths that are handled.
Args:
path (str) or None: A URI supported by this PathHandler. Must be a valid
absolute Unix path or None to set the cwd to None.
Returns:
bool: true if cwd was set without errors
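Example (illustrative sketch; "/tmp/workdir" is assumed to exist):
    path_manager.set_cwd("/tmp/workdir")
    # Relative paths are now resolved against the cwd by the native handler,
    # so this opens "/tmp/workdir/notes.txt".
    with path_manager.open("notes.txt", "w") as f:
        f.write("hello")
    path_manager.set_cwd(None)  # clear the cwd again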
"""
if path is None and self._cwd is None:
return True
if self.__get_path_handler(path or self._cwd)._set_cwd(path, **kwargs): # type: ignore
self._cwd = path
return True
return False
def register_handler(self, handler: PathHandler, allow_override: bool = False) -> None:
"""
Register a path handler associated with `handler._get_supported_prefixes`
URI prefixes.
Args:
handler (PathHandler)
allow_override (bool): allow overriding existing handler for prefix
"""
logger = logging.getLogger(__name__)
assert isinstance(handler, PathHandler), handler
# Allow override of `NativePathHandler` which is automatically
# instantiated by `PathManager`.
if isinstance(handler, NativePathHandler):
if allow_override:
self._native_path_handler = handler
else:
raise ValueError(
"`NativePathHandler` is registered by default. Use the "
"`allow_override=True` kwarg to override it."
)
return
for prefix in handler._get_supported_prefixes():
if prefix not in self._path_handlers:
self._path_handlers[prefix] = handler
continue
old_handler_type = type(self._path_handlers[prefix])
if allow_override:
# if using the global PathManager, show the warnings
global g_pathmgr
if self == g_pathmgr:
logger.warning(
f"[PathManager] Attempting to register prefix '{prefix}' from "
"the following call stack:\n" + "".join(traceback.format_stack(limit=5))
# show the most recent callstack
)
logger.warning(
f"[PathManager] Prefix '{prefix}' is already registered "
f"by {old_handler_type}. We will override the old handler. "
"To avoid such conflicts, create a project-specific PathManager "
"instead."
)
self._path_handlers[prefix] = handler
else:
raise KeyError(
f"[PathManager] Prefix '{prefix}' already registered by {old_handler_type}!"
)
# Sort path handlers in reverse order so longer prefixes take priority,
# eg: http://foo/bar before http://foo
self._path_handlers = OrderedDict(
sorted(self._path_handlers.items(), key=lambda t: t[0], reverse=True)
)
def set_strict_kwargs_checking(self, enable: bool) -> None:
"""
Toggles strict kwargs checking. If enabled, a ValueError is thrown if any
unused parameters are passed to a PathHandler function. If disabled, only
a warning is given.
With a centralized file API, there's a tradeoff between convenience and
correctness when delegating arguments to the proper I/O layers. An underlying
`PathHandler` may support custom arguments which should not be statically
exposed on the `PathManager` function. For example, a custom `HTTPURLHandler`
may want to expose a `cache_timeout` argument for `open()` which specifies
how old a locally cached resource can be before it's refetched from the
remote server. This argument would not make sense for a `NativePathHandler`.
If strict kwargs checking is disabled, `cache_timeout` can be passed to
`PathManager.open` which will forward the arguments to the underlying
handler. By default, checking is enabled, since silently forwarding unknown
arguments is innately unsafe: multiple `PathHandler`s could reuse arguments
with different semantic meanings or types.
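Example (illustrative sketch; `cache_timeout` is the hypothetical handler
argument described above):
    path_manager.set_strict_kwargs_checking(False)
    # Handlers that do not support `cache_timeout` now only emit a warning
    # instead of raising a ValueError.
    f = path_manager.open("https://example.com/data.bin", "rb", cache_timeout=3600)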
Args:
enable (bool)
"""
self._native_path_handler._strict_kwargs_check = enable
for handler in self._path_handlers.values():
handler._strict_kwargs_check = enable
class PathManagerFactory:
"""
PathManagerFactory is the class responsible for creating new PathManager
instances and removing them when no longer needed.
PathManager can be instantiated directly too, but it is recommended that
you use PathManagerFactory to create them.
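Example (illustrative sketch; "my_project" is an arbitrary key):
    pm = PathManagerFactory.get("my_project")       # created on first use
    same_pm = PathManagerFactory.get("my_project")  # returns the same instance
    PathManagerFactory.remove("my_project")         # drop it when no longer needed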
"""
GLOBAL_PATH_MANAGER = "global_path_manager"
pm_list = {}
@staticmethod
def get(key=GLOBAL_PATH_MANAGER) -> PathManagerBase:
"""
Get the path manager instance associated with a key.
A new instance will be created if there is no existing
instance associated with the key passed in.
Args:
key (str):
"""
if key not in PathManagerFactory.pm_list:
PathManagerFactory.pm_list[key] = PathManagerBase()
return PathManagerFactory.pm_list[key]
@staticmethod
def remove(key):
"""
Remove the path manager instance associated with a key.
Args:
key (str):
"""
if key in PathManagerFactory.pm_list:
_pm = PathManagerFactory.pm_list.pop(key) # noqa
del _pm
"""
A global instance of PathManager.
This global instance is provided for backward compatibility, but it is
recommended that clients use PathManagerFactory instead.
"""
g_pathmgr = PathManagerFactory.get()
PathManager = PathManagerBase()
PathManager.register_handler(HTTPURLHandler())
PathManager.register_handler(OneDrivePathHandler())
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import fnmatch
import hashlib
import json
import logging
import os
import shutil
import sys
import tempfile
from functools import wraps
from io import open
from pathlib import Path
import boto3
import requests
import wget
from botocore.config import Config
from botocore.exceptions import ClientError
from tqdm import tqdm
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
cache_home = Path(os.getenv("OF_CACHE_ROOT", Path.home() / ".of_cache"))
default_cache_path = str(cache_home / "libai")
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
If the url ends with .h5 (Keras HDF5 weights), adds '.h5' to the name
so that TF 2.0 can identify it as an HDF5 file
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3
/tensorflow/python/keras/engine/network.py#L1380)
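Example (illustrative sketch; the URL and etag are placeholders):
    name = url_to_filename("https://example.com/weights.h5", etag="abc123")
    # name == sha256(url) + "." + sha256(etag) + ".h5"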
"""
url_bytes = url.encode("utf-8")
url_hash = hashlib.sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode("utf-8")
etag_hash = hashlib.sha256(etag_bytes)
filename += "." + etag_hash.hexdigest()
if url.endswith(".h5"):
filename += ".h5"
return filename
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = default_cache_path
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + ".json"
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata["url"]
etag = metadata["etag"]
return url, etag
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
Args:
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
force_download: if True, re-download the file even if it's already cached in the cache dir.
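Example (illustrative sketch; the URL is a placeholder):
    local_file = cached_path("https://example.com/vocab.txt")
    # Subsequent calls with the same URL reuse the cached copy under
    # `default_cache_path` unless `force_download=True` is passed.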
"""
if cache_dir is None:
cache_dir = default_cache_path
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
if parsed.scheme in ("http", "https", "s3"):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(
url_or_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == "":
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
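# Example (illustrative sketch; the bucket and key below are placeholders):
#   split_s3_path("s3://my-bucket/models/bert/config.json")
#   # -> ("my-bucket", "models/bert/config.json")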
def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise EnvironmentError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url, proxies=None):
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url, temp_file, proxies=None):
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file, proxies=None):
req = requests.get(url, stream=True, proxies=proxies)
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = default_cache_path
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url, proxies=proxies)
else:
try:
response = requests.head(
url, allow_redirects=True, proxies=proxies, timeout=etag_timeout
)
if response.status_code != 200:
etag = None
else:
etag = response.headers.get("ETag")
except (EnvironmentError, requests.exceptions.Timeout):
etag = None
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*")
matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files))
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if not os.path.exists(cache_path) or force_download:
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info(
"%s not found in cache or force_download set to True, downloading to %s",
url,
temp_file.name,
)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file, proxies=proxies)
else:
http_get(url, temp_file, proxies=proxies)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, "wb") as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
output_string = json.dumps(meta)
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path
def get_md5(fname):
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
result = hash_md5.hexdigest()
return result
def download_file(out_path: str, url):
logger.info(f"downloading from {url} to {out_path}")
wget.download(url, out=out_path)
def get_data_from_cache(url, cache_dir=None, force_download=False, md5=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
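Example (illustrative sketch; URL and md5 are placeholders):
    path = get_data_from_cache(
        "https://example.com/dataset.tar.gz",
        md5="0123456789abcdef0123456789abcdef",
    )
    # If the cached copy's md5 does not match, it is re-downloaded.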
"""
if cache_dir is None:
cache_dir = default_cache_path
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
filename = url.split("/")[-1]
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If we already have the file, just check its md5 against the expected one (if provided)
if os.path.exists(cache_path) and md5 is not None:
local_file_md5 = get_md5(cache_path)
if local_file_md5 != md5:
os.unlink(cache_path)
download_file(cache_path, url)
# If the file is not cached yet, download it
if not os.path.exists(cache_path):
download_file(cache_path, url)
if not os.path.exists(cache_path) or force_download:
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info(
"%s not found in cache or force_download set to True, downloading to %s",
url,
temp_file.name,
)
# GET file object
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, "wb") as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
output_string = json.dumps(meta)
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import numpy as np
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/history_buffer.py
# --------------------------------------------------------
class HistoryBuffer:
"""
Track a series of scalar values and provide access to smoothed values over a
window or the global average of the series.
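Example (minimal sketch):
    buf = HistoryBuffer(max_length=3)
    for i, v in enumerate([1.0, 2.0, 3.0, 4.0]):
        buf.update(v, iteration=i)
    buf.latest()      # 4.0
    buf.avg(2)        # 3.5, mean of the last two values
    buf.global_avg()  # 2.5, includes the evicted value 1.0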
"""
def __init__(self, max_length: int = 1000000):
"""
Args:
max_length: maximal number of values that can be stored in the
buffer. When the capacity of the buffer is exhausted, old
values will be removed.
"""
self._max_length: int = max_length
self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
self._count: int = 0
self._global_avg: float = 0
def update(self, value: float, iteration: float = None):
"""
Add a new scalar value produced at certain iteration. If the length
of the buffer exceeds self._max_length, the oldest element will be
removed from the buffer.
"""
if iteration is None:
iteration = self._count
if len(self._data) == self._max_length:
self._data.pop(0)
self._data.append((value, iteration))
self._count += 1
self._global_avg += (value - self._global_avg) / self._count
def latest(self):
"""
Return the latest scalar value added to the buffer.
"""
return self._data[-1][0]
def median(self, window_size: int):
"""
Return the median of the latest `window_size` values in the buffer.
"""
return np.median([x[0] for x in self._data[-window_size:]])
def avg(self, window_size: int):
"""
Return the mean of the latest `window_size` values in the buffer.
"""
return np.mean([x[0] for x in self._data[-window_size:]])
def global_avg(self):
"""
Return the mean of all the elements in the buffer. Note that this
includes those getting removed due to limited buffer storage.
"""
return self._global_avg
def values(self):
"""
Returns:
list[(number, iteration)]: content of the current buffer.
"""
return self._data
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import functools
import logging
import os
import sys
import time
from collections import Counter
from termcolor import colored
from libai.utils.file_io import PathManager
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/logger.py
# --------------------------------------------------------
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
def setup_logger(output=None, distributed_rank=0, *, color=True, name="libai", abbrev_name=None):
"""
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
Set to "" to not log the root module in logs.
By default, will abbreviate "libai" to "lb" and leave other
modules unchanged.
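Example (illustrative sketch; "output" and the rank value are placeholders):
    logger = setup_logger(output="output", distributed_rank=0, name="libai")
    logger.info("training started")  # colored stdout on rank 0, plus output/log.txt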
"""
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
if abbrev_name is None:
abbrev_name = "lb" if name == "libai" else name
plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
)
# stdout logging: master only
if distributed_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
abbrev_name=str(abbrev_name),
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
# file logging: all workers
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if distributed_rank > 0:
filename = filename + ".rank{}".format(distributed_rank)
PathManager.mkdirs(os.path.dirname(filename))
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
# use 1K buffer if writing to cloud storage
io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1)
atexit.register(io.close)
return io
"""
Below are some other convenient logging methods.
They are mainly adopted from
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
"""
def _find_caller():
"""
Returns:
str: module name of the caller
tuple: a hashable key to be used to identify different callers
"""
frame = sys._getframe(2)
while frame:
code = frame.f_code
if os.path.join("utils", "logger.") not in code.co_filename:
mod_name = frame.f_globals["__name__"]
if mod_name == "__main__":
mod_name = "libai"
return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
frame = frame.f_back
_LOG_COUNTER = Counter()
_LOG_TIMER = {}
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
"""
Log only for the first n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
key (str or tuple[str]): the string(s) can be one of "caller" or
"message", which defines how to identify duplicated logs.
For example, if called with `n=1, key="caller"`, this function
will only log the first call from the same caller, regardless of
the message content.
If called with `n=1, key="message"`, this function will log the
same content only once, even if they are called from different places.
If called with `n=1, key=("caller", "message")`, this function
will skip logging only when the same caller has already logged the same message.
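Example (minimal sketch):
    for _ in range(100):
        # Emitted only twice in total, identified by message content.
        log_first_n(logging.WARNING, "deprecated option used", n=2, key="message")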
"""
if isinstance(key, str):
key = (key,)
assert len(key) > 0
caller_module, caller_key = _find_caller()
hash_key = ()
if "caller" in key:
hash_key = hash_key + caller_key
if "message" in key:
hash_key = hash_key + (msg,)
_LOG_COUNTER[hash_key] += 1
if _LOG_COUNTER[hash_key] <= n:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n(lvl, msg, n=1, *, name=None):
"""
Log once per n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
_LOG_COUNTER[key] += 1
if n == 1 or _LOG_COUNTER[key] % n == 1:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
"""
Log no more than once per n seconds.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
last_logged = _LOG_TIMER.get(key, None)
current_time = time.time()
if last_logged is None or current_time - last_logged >= n:
logging.getLogger(name or caller_module).log(lvl, msg)
_LOG_TIMER[key] = current_time
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import io
import logging
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import IO, Callable, Optional, Union
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/iopath/blob/main/iopath/common/non_blocking_io.py
# --------------------------------------------------------
"""
This file is used for asynchronous file operations.
When `opena` is called for the first time for a specific
`PathHandler`, a `NonBlockingIOManager` is instantiated. The
manager returns a `NonBlockingIO` (or `NonBlockingBufferedIO`)
instance to the caller, and the manager maintains all of the
thread management and data management.
"""
@dataclass
class PathData:
"""
Manage the IO job queue and polling thread for a single
path. This is done to ensure that write calls to the same
path are serialized so they are written in the same order
as they were called.
On each `f.write` call where `f` is of type `NonBlockingIO`,
we send the job to the manager where it is enqueued to the
Queue. The polling Thread picks up on the job, executes it,
waits for it to finish, and then continues to poll.
"""
queue: Queue
thread: Thread
class NonBlockingIOManager:
"""
All `opena` calls pass through this class so that it can
keep track of the threads for proper cleanup at the end
of the script. Each path that is opened with `opena` is
assigned a single queue and polling thread that is kept
open until it is cleaned up by `PathManager.async_join()`.
"""
def __init__(
self,
buffered: Optional[bool] = False,
executor: Optional[concurrent.futures.Executor] = None,
) -> None:
"""
Args:
buffered (bool): IO instances will be `NonBlockingBufferedIO`
or `NonBlockingIO` based on this value. This bool is set
manually for each `PathHandler` in `_opena`.
executor: User can optionally attach a custom executor to
perform async operations through `PathHandler.__init__`.
"""
self._path_to_data = {} # Map from path to `PathData` object
self._buffered = buffered
self._IO = NonBlockingBufferedIO if self._buffered else NonBlockingIO
self._pool = executor or concurrent.futures.ThreadPoolExecutor()
def get_non_blocking_io(
self,
path: str,
io_obj: Union[IO[str], IO[bytes]],
callback_after_file_close: Optional[Callable[[None], None]] = None,
buffering: Optional[int] = -1,
) -> Union[IO[str], IO[bytes]]:
"""
Called by `PathHandler._opena` with the path and returns a
`NonBlockingIO` instance.
Args:
path (str): A path str to operate on. This path should be
simplified to ensure that each absolute path has only a single
path str that maps onto it. For example, in `NativePathHandler`,
we can use `os.path.normpath`.
io_obj (IO): a reference to the IO object returned by the
`PathHandler._open` function.
callback_after_file_close (Callable): An optional argument that can
be passed to perform operations that depend on the asynchronous
writes being completed. The file is first written to the local
disk and then the callback is executed.
buffering (int): An optional argument to set the buffer size for
buffered asynchronous writing.
"""
if not self._buffered and buffering != -1:
raise ValueError(
"NonBlockingIO is not using a buffered writer but `buffering` "
f"arg is set to non-default value of {buffering} != -1."
)
if path not in self._path_to_data:
# Initialize job queue and a polling thread
queue = Queue()
t = Thread(target=self._poll_jobs, args=(queue,))
t.start()
# Store the `PathData`
self._path_to_data[path] = PathData(queue, t)
kwargs = {} if not self._buffered else {"buffering": buffering}
return self._IO(
notify_manager=lambda io_callable: ( # Pass async jobs to manager
self._path_to_data[path].queue.put(io_callable)
),
io_obj=io_obj,
callback_after_file_close=callback_after_file_close,
**kwargs,
)
def _poll_jobs(self, queue: Queue) -> None:
"""
A single thread runs this loop. It waits for an IO callable to be
placed in a specific path's `Queue` where the queue contains
callable functions. It then waits for the IO job to be completed
before looping to ensure write order.
"""
while True:
# `func` is a callable function (specifically a lambda function)
# and can be any of:
# - func = file.write(b)
# - func = file.close()
# - func = None
func = queue.get() # Blocks until item read.
if func is None: # Thread join signal.
break
self._pool.submit(func).result() # Wait for job to finish.
def _join(self, path: Optional[str] = None) -> bool:
"""
Waits for write jobs for a specific path or waits for all
write jobs for the path handler if no path is provided.
Args:
path (str): Pass in a file path and will wait for the
asynchronous jobs to be completed for that file path.
If no path is passed in, then all threads operating
on all file paths will be joined.
"""
if path and path not in self._path_to_data:
raise ValueError(
f"{path} has no async IO associated with it. "
f"Make sure `opena({path})` is called first."
)
# If a `_close` call fails, we print the error and continue
# closing the rest of the IO objects.
paths_to_close = [path] if path else list(self._path_to_data.keys())
success = True
for _path in paths_to_close:
try:
path_data = self._path_to_data.pop(_path)
path_data.queue.put(None)
path_data.thread.join()
except Exception:
logger = logging.getLogger(__name__)
logger.exception(f"`NonBlockingIO` thread for {_path} failed to join.")
success = False
return success
def _close_thread_pool(self) -> bool:
"""
Closes the ThreadPool.
"""
try:
self._pool.shutdown()
except Exception:
logger = logging.getLogger(__name__)
logger.exception("`NonBlockingIO` thread pool failed to close.")
return False
return True
# NOTE: We currently only support asynchronous writes (not reads).
class NonBlockingIO(io.IOBase):
def __init__(
self,
notify_manager: Callable[[Callable[[], None]], None],
io_obj: Union[IO[str], IO[bytes]],
callback_after_file_close: Optional[Callable[[None], None]] = None,
) -> None:
"""
Returned to the user on an `opena` call. Uses a Queue to manage the
IO jobs that need to be run to ensure order preservation and a
polling Thread that checks the Queue. Implementation for these are
lifted to `NonBlockingIOManager` since `NonBlockingIO` closes upon
leaving the context block.
NOTE: Writes to the same path are serialized so they are written in
the same order as they were called but writes to distinct paths can
happen concurrently.
Args:
notify_manager (Callable): a callback function passed in from the
`NonBlockingIOManager` so that all IO jobs can be stored in
the manager. It takes in a single argument, namely another
callable function.
Example usage:
```
notify_manager(lambda: file.write(data))
notify_manager(lambda: file.close())
```
Here, we tell `NonBlockingIOManager` to add a write callable
to the path's Queue, and then to add a close callable to the
path's Queue. The path's polling Thread then executes the write
callable, waits for it to finish, and then executes the close
callable. Using `lambda` allows us to pass callables to the
manager.
io_obj (IO): a reference to the IO object returned by the
`PathHandler._open` function.
callback_after_file_close (Callable): An optional argument that can
be passed to perform operations that depend on the asynchronous
writes being completed. The file is first written to the local
disk and then the callback is executed.
"""
super().__init__()
self._notify_manager = notify_manager
self._io = io_obj
self._callback_after_file_close = callback_after_file_close
self._close_called = False
def readable(self) -> bool:
return False
def writable(self) -> bool:
return True
def seekable(self) -> bool:
return True
def write(self, b: Union[bytes, bytearray]) -> None:
"""
Called on `f.write()`. Gives the manager the write job to call.
"""
self._notify_manager(lambda: self._io.write(b))
def seek(self, offset: int, whence: int = 0) -> int:
"""
Called on `f.seek()`.
"""
self._notify_manager(lambda: self._io.seek(offset, whence))
def tell(self) -> int:
"""
Called on `f.tell()`.
"""
raise ValueError("ioPath async writes does not support `tell` calls.")
def truncate(self, size: Optional[int] = None) -> int:
"""
Called on `f.truncate()`.
"""
self._notify_manager(lambda: self._io.truncate(size))
def close(self) -> None:
"""
Called on `f.close()` or automatically by the context manager.
We add the `close` call to the file's queue to make sure that
the file is not closed before all of the write jobs are complete.
"""
# `ThreadPool` first closes the file and then executes the callback.
# We only execute the callback once even if there are multiple
# `f.close` calls.
self._notify_manager(lambda: self._io.close())
if not self._close_called and self._callback_after_file_close:
self._notify_manager(self._callback_after_file_close)
self._close_called = True
# NOTE: To use this class, use `buffered=True` in `NonBlockingIOManager`.
# NOTE: This class expects the IO mode to be buffered.
class NonBlockingBufferedIO(io.IOBase):
MAX_BUFFER_BYTES = 10 * 1024 * 1024 # 10 MiB
def __init__(
self,
notify_manager: Callable[[Callable[[], None]], None],
io_obj: Union[IO[str], IO[bytes]],
callback_after_file_close: Optional[Callable[[None], None]] = None,
buffering: int = -1,
) -> None:
"""
Buffered version of `NonBlockingIO`. All write data is stored in an
IO buffer until the buffer is full, or `flush` or `close` is called.
Args:
Same as `NonBlockingIO` args.
buffering (int): An optional argument to set the buffer size for
buffered asynchronous writing.
"""
super().__init__()
self._notify_manager = notify_manager
self._io = io_obj
self._callback_after_file_close = callback_after_file_close
self._buffers = [io.BytesIO()]
self._buffer_size = buffering if buffering > 0 else self.MAX_BUFFER_BYTES
self._close_called = False
def readable(self) -> bool:
return False
def writable(self) -> bool:
return True
def seekable(self) -> bool:
return False
def write(self, b: Union[bytes, bytearray]) -> None:
"""
Called on `f.write()`. Gives the manager the write job to call.
"""
buffer = self._buffers[-1]
with memoryview(b) as view:
buffer.write(view)
if buffer.tell() < self._buffer_size:
return
self.flush()
def close(self) -> None:
"""
Called on `f.close()` or automatically by the context manager.
We add the `close` call to the file's queue to make sure that
the file is not closed before all of the write jobs are complete.
"""
self.flush()
# Close the last buffer created by `flush`.
self._notify_manager(lambda: self._buffers[-1].close())
# `ThreadPool` first closes the file and then executes the callback.
self._notify_manager(lambda: self._io.close())
if not self._close_called and self._callback_after_file_close:
self._notify_manager(self._callback_after_file_close)
self._close_called = True
def flush(self) -> None:
"""
Called on `f.write()` when the buffer is filled (or overfilled). Can
also be called explicitly by the user.
NOTE: Buffering is strict: a buffer that exceeds `self._buffer_size`
is split into multiple write jobs, each writing at most
`self._buffer_size` bytes.
"""
buffer = self._buffers[-1]
if buffer.tell() == 0:
return
pos = 0
total_size = buffer.seek(0, io.SEEK_END)
view = buffer.getbuffer()
# Chunk the buffer in case it is larger than the buffer size.
while pos < total_size:
item = view[pos : pos + self._buffer_size]
# `item=item` is needed due to Python's late binding closures.
self._notify_manager(lambda item=item: self._io.write(item))
pos += self._buffer_size
# Close buffer immediately after being written to file and create
# a new buffer.
self._notify_manager(lambda: buffer.close())
self._buffers.append(io.BytesIO())
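# Minimal, hypothetical sketch (NOT part of iopath/libai) of the chunking performed
# in `NonBlockingBufferedIO.flush` above: a buffer larger than the configured size
# is split into write jobs of at most `buffer_size` bytes each. The `chunk=chunk`
# default argument mirrors the `item=item` trick above, binding the current slice
# when the lambda is created rather than when it is called.
import io
def _demo_chunked_flush(data: bytes, buffer_size: int) -> list:
    """Split `data` into write jobs of at most `buffer_size` bytes (illustration only)."""
    buffer = io.BytesIO(data)
    jobs = []
    pos = 0
    total_size = buffer.seek(0, io.SEEK_END)
    view = buffer.getbuffer()
    while pos < total_size:
        chunk = view[pos : pos + buffer_size]
        # Without `chunk=chunk`, every job would see the final value of `chunk`.
        jobs.append(lambda sink, chunk=chunk: sink.write(chunk))
        pos += buffer_size
    return jobs
# Example (hypothetical): 25 bytes with a 10-byte buffer yields jobs of 10, 10 and 5 bytes.
#   sink = io.BytesIO()
#   for job in _demo_chunked_flush(b"x" * 25, buffer_size=10):
#       job(sink)
#   assert sink.getvalue() == b"x" * 25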
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from time import perf_counter
from typing import Optional
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/timer.py
# --------------------------------------------------------
class Timer:
"""
A timer which computes the time elapsed since the start/reset of the timer.
"""
def __init__(self):
self.reset()
def reset(self):
"""
Reset the timer.
"""
self._start = perf_counter()
self._paused: Optional[float] = None
self._total_paused = 0
self._count_start = 1
def pause(self):
"""
Pause the timer.
"""
if self._paused is not None:
raise ValueError("Trying to pause a Timer that is already paused!")
self._paused = perf_counter()
def is_paused(self) -> bool:
"""
Returns:
bool: whether the timer is currently paused
"""
return self._paused is not None
def resume(self):
"""
Resume the timer.
"""
if self._paused is None:
raise ValueError("Trying to resume a Timer that is not paused!")
self._total_paused += perf_counter() - self._paused
self._paused = None
self._count_start += 1
def seconds(self) -> float:
"""
Returns:
(float): the total number of seconds since the start/reset of the
timer, excluding the time when the timer is paused.
"""
if self._paused is not None:
end_time: float = self._paused # type: ignore
else:
end_time = perf_counter()
return end_time - self._start - self._total_paused
def avg_seconds(self) -> float:
"""
Returns:
(float): the average number of seconds between every start/reset and
pause.
"""
return self.seconds() / self._count_start
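# A short usage sketch (not part of the library) showing the accounting above:
# paused time is excluded from `seconds()`, and `avg_seconds()` divides by the
# number of start/resume segments. The sleep durations are arbitrary.
if __name__ == "__main__":
    import time
    timer = Timer()
    time.sleep(0.2)  # measured
    timer.pause()
    time.sleep(0.5)  # excluded: the timer is paused
    timer.resume()
    time.sleep(0.2)  # measured
    # Roughly 0.4s measured across 2 segments, so avg_seconds() is roughly 0.2s.
    print(f"seconds={timer.seconds():.2f} avg_seconds={timer.avg_seconds():.2f}")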
__version__ = '0.2.0'
git_version = 'Unknown'
# Model name
modelName=Bert-Large
# Model description
modelDescription=Bert-Large
# Application scenarios (multiple tags separated by commas)
appScenario=Intelligent chat assistant
# Framework type (multiple tags separated by commas)
frameType=OneFlow,Libai