Commit 9fdb7dab authored by yuguo960516

bloom
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .build import build_optimizer, get_default_optimizer_params
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import defaultdict
from typing import Any, Dict, List
import oneflow as flow
from libai.config import instantiate
from libai.layers import LayerNorm
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/solver/build.py
# --------------------------------------------------------
def build_optimizer(cfg, model):
"""
Build an optimizer from config.
"""
cfg.params.model = model
optim = instantiate(cfg)
return optim
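# Usage sketch (added for illustration, not part of the original file): build an AdamW
# optimizer from a LazyConfig node. `LazyCall` is assumed to come from libai.config's
# lazy-config system, and all hyperparameter values below are placeholders.
#
#   import oneflow as flow
#   from libai.config import LazyCall
#
#   optim_cfg = LazyCall(flow.optim.AdamW)(
#       params=LazyCall(get_default_optimizer_params)(
#           weight_decay_norm=0.0,
#           weight_decay_bias=0.0,
#       ),
#       lr=1e-4,
#       weight_decay=0.01,
#   )
#   optimizer = build_optimizer(optim_cfg, model)  # model: flow.nn.Module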
def get_default_optimizer_params(
model,
base_lr=None,
weight_decay=None,
weight_decay_norm=None,
weight_decay_bias=None,
clip_grad_max_norm=None,
clip_grad_norm_type=None,
overrides=None,
):
"""
Get the default parameter list for the optimizer, with support for a few types of overrides.
If no overrides are needed, it is equivalent to `model.parameters()`.
Arguments:
base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
weight_decay: weight decay for every group by default. Can be omitted to use the one
in optimizer.
weight_decay_norm: override weight decay for params in normalization layers
weight_decay_bias: override weight decay for bias parameters
overrides: if not `None`, provides values for optimizer hyperparameters
(LR, weight decay) for module parameters with a given name; e.g.
``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
weight decay values for all module parameters named `embedding`.
For common transformer models, ``weight_decay_norm`` and ``weight_decay_bias``
are usually set to 0.
Example:
::
flow.optim.AdamW(
get_default_optimizer_params(model, weight_decay_norm=0, weight_decay_bias=0),
lr=0.01,
weight_decay=1e-4
)
"""
if overrides is None:
overrides = {}
defaults = {}
if base_lr is not None:
defaults["lr"] = base_lr
if weight_decay is not None:
defaults["weight_decay"] = weight_decay
if clip_grad_max_norm is not None and clip_grad_norm_type is not None:
defaults["clip_grad_max_norm"] = clip_grad_max_norm
defaults["clip_grad_norm_type"] = clip_grad_norm_type
bias_overrides = {}
if weight_decay_bias is not None:
bias_overrides["weight_decay"] = weight_decay_bias
if len(bias_overrides):
if "bias" in overrides:
raise ValueError("Conflicting overrides for 'bias'")
overrides["bias"] = bias_overrides
norm_module_types = (
LayerNorm,
flow.nn.BatchNorm1d,
flow.nn.BatchNorm2d,
flow.nn.BatchNorm3d,
flow.nn.GroupNorm,
flow.nn.InstanceNorm1d,
flow.nn.InstanceNorm2d,
flow.nn.InstanceNorm3d,
flow.nn.FusedBatchNorm1d,
flow.nn.FusedBatchNorm2d,
flow.nn.FusedBatchNorm3d,
)
params = []
memo = set()
for module in model.modules():
for model_param_name, value in module.named_parameters(recurse=False):
if not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
hyperparams = copy.copy(defaults)
if isinstance(module, norm_module_types) and weight_decay_norm is not None:
hyperparams["weight_decay"] = weight_decay_norm
hyperparams.update(overrides.get(model_param_name, {}))
params.append({"params": [value], **hyperparams})
return reduce_param_groups(params)
def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Transform parameter groups into per-parameter structure.
Later items in `params` can overwrite parameters set in previous items.
"""
ret = defaultdict(dict)
for item in params:
assert "params" in item
cur_params = {x: y for x, y in item.items() if x != "params"}
for param in item["params"]:
ret[param].update({"params": [param], **cur_params})
return list(ret.values())
def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Reorganize the parameter groups and merge duplicated groups.
The number of parameter groups needs to be as small as possible in order
to efficiently use the OneFlow multi-tensor optimizer. Therefore, instead
of using one parameter group per single parameter, we reorganize the
parameter groups and merge duplicated groups. This approach speeds
up the multi-tensor optimizer significantly.
"""
params = _expand_param_groups(params)
groups = defaultdict(list) # re-group all parameter groups by their hyperparams
for item in params:
cur_params = tuple((x, y) for x, y in item.items() if x != "params")
groups[cur_params].extend(item["params"])
ret = []
for param_keys, param_values in groups.items():
cur = {kv[0]: kv[1] for kv in param_keys}
cur["params"] = param_values
ret.append(cur)
return ret
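# Illustrative sketch (added, not in the original file): two per-parameter groups that
# share identical hyperparameters collapse into a single merged group. Plain Python
# objects stand in for flow parameters here.
#
#   w, b = object(), object()
#   groups = [
#       {"params": [w], "weight_decay": 0.01},
#       {"params": [b], "weight_decay": 0.01},
#   ]
#   reduce_param_groups(groups)
#   # -> [{"weight_decay": 0.01, "params": [w, b]}]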
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .build import build_lr_scheduler
from .lr_scheduler import (
WarmupCosineAnnealingLR,
WarmupCosineLR,
WarmupExponentialLR,
WarmupMultiStepLR,
WarmupPolynomialLR,
WarmupStepLR,
)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from libai.config import instantiate
def build_lr_scheduler(cfg, optimizer):
"""Build learning rate scheduler, defined by ``cfg``."""
cfg.optimizer = optimizer
scheduler = instantiate(cfg)
return scheduler
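# Usage sketch (added for illustration): assumes `LazyCall` from libai.config and the
# WarmupCosineLR defined in the sibling lr_scheduler module (import path assumed);
# all numeric values are placeholders.
#
#   from libai.config import LazyCall
#   from libai.scheduler import WarmupCosineLR  # path assumed
#
#   scheduler_cfg = LazyCall(WarmupCosineLR)(
#       max_iter=10000,
#       warmup_factor=0.001,
#       warmup_iter=1000,
#   )
#   lr_scheduler = build_lr_scheduler(scheduler_cfg, optimizer)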
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import oneflow as flow
logger = logging.getLogger(__name__)
def WarmupCosineLR(
optimizer: flow.optim.Optimizer,
max_iter: int,
warmup_factor: float,
warmup_iter: int,
alpha: float = 0.0,
warmup_method: str = "linear",
):
"""Create a schedule with a learning rate that decreases following
the values of the Cosine function from the initial lr set in the
optimizer down to 0, after a warmup period during which it increases linearly
from 0 to the initial lr set in the optimizer.
Args:
optimizer (flow.optim.Optimizer): Wrapped optimizer.
max_iter (int): Total training iters.
warmup_factor (float): The warmup factor.
warmup_iter (int): The number of warmup steps.
alpha (float, optional): The learning rate scale factor (:math:`\\alpha`). Defaults to 0.0.
warmup_method (str, optional): The method of warmup, you can choose "linear" or "constant".
In linear mode, the multiplication factor starts with warmup_factor in
the first epoch and then increases linearly to reach 1. Defaults to "linear".
"""
cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(
optimizer, decay_steps=max_iter, alpha=alpha
)
if warmup_iter == 0:
logger.warning("warmup iters equals to zero, return CosineLR")
return cosine_decay_lr
elif warmup_iter > max_iter:
logger.warning("warmup iters is larger than the total training iters")
warmup_cosine_lr = flow.optim.lr_scheduler.WarmUpLR(
cosine_decay_lr,
warmup_factor=warmup_factor,
warmup_iters=warmup_iter,
warmup_method=warmup_method,
)
return warmup_cosine_lr
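# Hedged usage sketch (added, not in the original file): wrap a plain optimizer with the
# warmup + cosine-decay schedule above; all numbers are illustrative.
#
#   optimizer = flow.optim.AdamW(model.parameters(), lr=1e-4)
#   scheduler = WarmupCosineLR(
#       optimizer, max_iter=10000, warmup_factor=0.001, warmup_iter=1000
#   )
#   for step in range(10000):
#       ...              # forward / backward
#       optimizer.step()
#       optimizer.zero_grad()
#       scheduler.step()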
def WarmupCosineAnnealingLR(
optimizer: flow.optim.Optimizer,
max_iter: int,
warmup_factor: float,
warmup_iter: int,
eta_min: float = 0.0,
warmup_method: str = "linear",
):
"""Create a schedule with a learning rate that decreases following
the values of the Cosine Annealing function from the initial
lr set in the optimizer down to 0, after a warmup period during which
it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer (flow.optim.Optimizer): Wrapped optimizer.
max_iter (int): Total training iters.
warmup_factor (float): The warmup factor.
warmup_iter (int): The number of warmup steps.
eta_min (float, optional): Minimum learning rate. Defaults to 0.0.
warmup_method (str, optional): The method of warmup, you can choose "linear" or "constant".
In linear mode, the multiplication factor starts with warmup_factor in the first epoch
and then increases linearly to reach 1. Defaults to "linear".
"""
cosine_annealing_lr = flow.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=max_iter, eta_min=eta_min
)
if warmup_iter == 0:
logger.warning("warmup iters equals to zero, return CosineAnnealingLR")
return cosine_annealing_lr
warmup_cosine_annealing_lr = flow.optim.lr_scheduler.WarmUpLR(
cosine_annealing_lr,
warmup_factor=warmup_factor,
warmup_iters=warmup_iter,
warmup_method=warmup_method,
)
return warmup_cosine_annealing_lr
def WarmupStepLR(
optimizer: flow.optim.Optimizer,
max_iter: int,
warmup_factor: float,
warmup_iter: int,
step_size: int,
gamma: float = 0.1,
warmup_method: str = "linear",
):
"""Create a schedule with a learning rate that decreases following the values of the Step
function from the initial lr set in the optimizer down to 0, after a warmup period during which
it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer (flow.optim.Optimizer): Wrapped optimizer.
max_iter (int): Total training iters.
warmup_factor (float): The warmup factor.
warmup_iter (int): The number of warmup steps.
step_size (int): Period of learning rate decay.
gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.1.
warmup_method (str, optional): The method of warmup, you can choose "linear" or "constant".
In linear mode, the multiplication factor starts with warmup_factor in the first
epoch and then increases linearly to reach 1. Defaults to "linear".
"""
step_lr = flow.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
if warmup_iter == 0:
logger.warning("warmup iters equals to zero, return StepLR")
return step_lr
warmup_step_lr = flow.optim.lr_scheduler.WarmUpLR(
step_lr,
warmup_factor=warmup_factor,
warmup_iters=warmup_iter,
warmup_method=warmup_method,
)
return warmup_step_lr
def WarmupMultiStepLR(
optimizer: flow.optim.Optimizer,
max_iter: int,
warmup_factor: float,
warmup_iter: int,
milestones: list,
gamma: float = 0.1,
warmup_method: str = "linear",
):
"""Create a schedule with a learning rate that decreases following the values of the MultiStep
function from the initial lr set in the optimizer down to 0, after a warmup period during which
it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer (flow.optim.Optimizer): Wrapped optimizer.
max_iter (int): Total training iters.
warmup_factor (float): The warmup factor.
warmup_iter (int): The number of warmup steps.
milestones (list): List of step indices. Must be increasing.
gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.1.
warmup_method (str, optional): The method of warmup, you can choose "linear" or "constant".
In linear mode, the multiplication factor starts with warmup_factor in the first
epoch and then increases linearly to reach 1. Defaults to "linear".
"""
multistep_lr = flow.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=milestones, gamma=gamma
)
if warmup_iter == 0:
logger.warning("warmup iters equals to zero, return MultiStepLR")
return multistep_lr
warmup_multistep_lr = flow.optim.lr_scheduler.WarmUpLR(
multistep_lr,
warmup_factor=warmup_factor,
warmup_iters=warmup_iter,
warmup_method=warmup_method,
)
return warmup_multistep_lr
def WarmupExponentialLR(
optimizer: flow.optim.Optimizer,
max_iter: int,
gamma: float,
warmup_factor: float,
warmup_iter: int,
warmup_method: str = "linear",
):
"""Create a schedule with a learning rate that decreases following the values of
the Exponential function from the initial lr set in the optimizer down to 0,
after a warmup period during which it increases linearly from 0 to the
initial lr set in the optimizer.
Args:
optimizer (flow.optim.Optimizer): Wrapped optimizer.
max_iter (int): Total training iters.
gamma (float): Multiplicative factor of learning rate decay.
warmup_factor (float): The warmup factor.
warmup_iter (int): The number of warmup steps.
warmup_method (str, optional): The method of warmup, you can choose "linear" or "constant".
In linear mode, the multiplication factor starts with warmup_factor in the first epoch
and then increases linearly to reach 1. Defaults to "linear".
"""
exponential_lr = flow.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
if warmup_iter == 0:
logger.warning("warmup iters equals to zero, return ExponentialLR")
return exponential_lr
warmup_exponential_lr = flow.optim.lr_scheduler.WarmUpLR(
exponential_lr,
warmup_factor=warmup_factor,
warmup_iters=warmup_iter,
warmup_method=warmup_method,
)
return warmup_exponential_lr
def WarmupPolynomialLR(
optimizer: flow.optim.Optimizer,
max_iter: int,
warmup_factor: float,
warmup_iter: int,
end_learning_rate: float = 0.0001,
power: float = 1.0,
cycle: bool = False,
warmup_method: str = "linear",
):
"""Create a schedule with a learning rate that decreases as a polynomial decay from
the initial lr set in the optimizer to the end lr defined by ``end_learning_rate``,
after a warmup period during which it increases linearly from 0 to the
initial lr set in the optimizer.
Args:
optimizer (flow.optim.Optimizer): Wrapped optimizer.
max_iter (int): Total training iters.
warmup_factor (float): The warmup factor.
warmup_iter (int): The number of warmup steps.
end_learning_rate (float, optional): The final learning rate. Defaults to 0.0001.
power (float, optional): The power of polynomial. Defaults to 1.0.
cycle (bool, optional): If cycle is True, the scheduler will decay the learning rate
every decay steps. Defaults to False.
warmup_method (str, optional): The method of warmup, you can choose "linear" or "constant".
In linear mode, the multiplication factor starts with warmup_factor in the first
epoch and then increases linearly to reach 1. Defaults to "linear".
"""
polynomial_lr = flow.optim.lr_scheduler.PolynomialLR(
optimizer,
decay_batch=max_iter,
end_learning_rate=end_learning_rate,
power=power,
cycle=cycle,
)
if warmup_iter == 0:
logger.warning("warmup iters equals to zero, return PolynomialLR")
return polynomial_lr
warmup_polynomial_lr = flow.optim.lr_scheduler.WarmUpLR(
polynomial_lr,
warmup_factor=warmup_factor,
warmup_iters=warmup_iter,
warmup_method=warmup_method,
)
return warmup_polynomial_lr
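# Hedged sketch (added): a linear decay (power=1.0) from the base lr down to
# end_learning_rate over max_iter steps, preceded by 500 warmup steps. All values
# are illustrative only.
#
#   optimizer = flow.optim.SGD(model.parameters(), lr=1e-3)
#   scheduler = WarmupPolynomialLR(
#       optimizer,
#       max_iter=5000,
#       warmup_factor=0.01,
#       warmup_iter=500,
#       end_learning_rate=1e-5,
#       power=1.0,
#   )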
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .build import build_tokenizer
from .tokenization_bert import BertTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_t5 import T5Tokenizer
from .tokenization_base import PreTrainedTokenizer
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from libai.config import instantiate
logger = logging.getLogger(__name__)
def build_tokenizer(cfg):
"""Initialize tokenizer."""
tokenizer = instantiate(cfg.tokenizer)
if cfg.append_eod and tokenizer.eod_token is None:
if tokenizer.eos_token is not None:
tokenizer.eod_token = tokenizer.eos_token
else:
tokenizer.eod_token = tokenizer.pad_token
return tokenizer
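# Usage sketch (added for illustration): `cfg` only needs `tokenizer` and `append_eod`
# attributes. SimpleNamespace, the vocab path, and the BertTokenizer keyword are
# assumptions made for this example; BertTokenizer itself is the class exported in
# the package __init__.py above.
#
#   from types import SimpleNamespace
#   from libai.config import LazyCall
#
#   cfg = SimpleNamespace(
#       tokenizer=LazyCall(BertTokenizer)(vocab_file="/path/to/vocab.txt"),
#       append_eod=True,
#   )
#   tokenizer = build_tokenizer(cfg)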
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""copy from HuggingFace transformer repo, to tokenize the sentence.
This class only focus on tokenization, converting token to id and their inverse operation.
It does not construct inputs using special symbols."""
import copy
import itertools
import json
import logging
import os
import unicodedata
from io import open
from typing import Dict, List, Optional, Union
import numpy as np
import oneflow as flow
from libai.utils import distributed as dist
from libai.utils.file_io import PathManager
from libai.utils.file_utils import cached_path
logger = logging.getLogger(__name__)
def _is_whitespace(char):
"""Checks whether `char` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `char` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `char` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (
(cp >= 33 and cp <= 47)
or (cp >= 58 and cp <= 64)
or (cp >= 91 and cp <= 96)
or (cp >= 123 and cp <= 126)
):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
class PreTrainedTokenizer(object):
"""
Base class for all tokenizers.
Handles all the shared methods for tokenization and special tokens, the methods for
downloading/caching/loading pretrained tokenizers, as well as adding tokens to the vocabulary.
This class also contains the added tokens in a unified way on top of all tokenizers, so we don't
have to handle the specific vocabulary augmentation methods of the various underlying
dictionary structures (BPE, sentencepiece...).
Class attributes (overridden by derived classes):
``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of
each vocabulary file required by the model, and as associated values, the filename for
saving the associated file (string).
``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the
``__init__`` keyword name of each vocabulary file required by the model, the low-level
being the `short-cut-names` (string) of the pretrained models with, as associated values,
the `url` (string) to the associated pretrained vocabulary file.
``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string)
of the pretrained models, and as associated values, the maximum length of the sequence
inputs of this model, or None if the model has no maximum input size.
``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names`
(string) of the pretrained models, and as associated values, a dictionary of specific
arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained
model when loading the tokenizer with the ``from_pretrained()`` method.
Args:
bos_token (:obj:`str`, `optional`): A special token representing the beginning of a
sentence.
eos_token (:obj:`str`, `optional`): A special token representing the end of a sentence.
unk_token (:obj:`str`, `optional`): A special token representing an out-of-vocabulary token.
sep_token (:obj:`str`, `optional`): A special token separating two different sentences in
the same input (used by BERT for instance).
pad_token (:obj:`str`, `optional`): A special token used to make arrays of tokens the same
size for batching purpose.
Will then be ignored by attention mechanisms or loss computation.
cls_token (:obj:`str`, `optional`): A special token representing the class of the input
(used by BERT for instance).
mask_token (:obj:`str`, `optional`): A special token representing a masked token (used by
masked-language modeling pretraining objectives, like BERT).
eod_token (:obj:`str`, `optional`): A special token representing the end of a document.
additional_special_tokens (tuple or list of :obj:`str`, `optional`):
A tuple or a list of additional special tokens.
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
pretrained_init_configuration = {}
max_model_input_sizes = {}
SPECIAL_TOKENS_ATTRIBUTES = [
"bos_token",
"eos_token",
"unk_token",
"sep_token",
"pad_token",
"cls_token",
"mask_token",
"eod_token",
"additional_special_tokens",
]
def __init__(self, verbose=True, **kwargs):
self._bos_token = None
self._eos_token = None
self._unk_token = None
self._sep_token = None
self._pad_token = None
self._cls_token = None
self._mask_token = None
self._eod_token = None
self._additional_special_tokens = []
self.verbose = verbose
# Added tokens - We store this for both slow and fast tokenizers
# until the serialization of Fast tokenizers is updated
self.added_tokens_encoder: Dict[str, int] = {}
self.added_tokens_decoder: Dict[int, str] = {}
self.unique_no_split_tokens: List[str] = []
# inputs and kwargs for saving and re-loading
# (see ``from_pretrained`` and ``save_pretrained``)
self.init_inputs = ()
self.init_kwargs = {}
# We directly set the hidden value to allow initialization with special tokens
# which are not yet in the vocabulary. Necessary for serialization/de-serialization
for key, value in kwargs.items():
if value is None:
continue
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens":
assert all(
isinstance(t, str) for t in value
), "One of the tokens is not a string"
setattr(self, key, list(value))
elif isinstance(value, str):
setattr(self, key, value)
else:
raise TypeError(f"special token {key} has to be str but got: {type(value)}")
@classmethod
def from_pretrained(cls, *inputs, **kwargs):
r"""
Instantiate a :class:`~PreTrainedTokenizer` (or a derived class) from a
predefined tokenizer.
Args:
pretrained_model_name_or_path(`str` or `os.PathLike`):
Can be either:
- a string with the `shortcut name` of a predefined tokenizer to load from cache
or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing vocabulary files required by the tokenizer,
for instance saved using the :func:`~PreTrainedTokenizer.save_pretrained`
method, e.g., ``./my_model_directory/``.
- (not applicable to all derived classes) a path or url to a single saved
vocabulary file if and only if the tokenizer only requires a single vocabulary
file (e.g. Bert, XLNet), e.g., ``./my_model_directory/vocab.txt``.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded predefined tokenizer vocabulary files
should be cached if the standard cache should not be used.
force_download: (`optional`) boolean, default False:
Force to (re-)download the vocabulary files and override the cached versions if
they exist.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint,
e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
inputs: (`optional`) positional arguments: will be passed to the
Tokenizer ``__init__`` method.
kwargs: (`optional`) keyword arguments: will be passed to the
Tokenizer ``__init__`` method. Can be used to set special tokens
like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``,
``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``.
See parameters in the doc string of :class:`~PreTrainedTokenizer`
for details.
Examples:
.. code-block:: python
# We can't directly instantiate the base class `PreTrainedTokenizer`, so let's
# show our examples on a derived class: BertTokenizer
# Download vocabulary from S3 and cache.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# If vocabulary files are in a directory (e.g. tokenizer was
# saved using `save_pretrained('./test/saved_model/')`)
tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')
# If the tokenizer uses a single vocabulary file, you can point directly to this file
tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')
# You can link tokens to special vocabulary when instantiating
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
# You should be sure '<unk>' is in the vocabulary when doing that.
# Otherwise, use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead.
assert tokenizer.unk_token == '<unk>'
"""
return cls._from_pretrained(*inputs, **kwargs)
@classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
init_configuration = {}
if pretrained_model_name_or_path in s3_models:
# Get the vocabulary from AWS S3 bucket
for file_id, map_list in cls.pretrained_vocab_files_map.items():
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
if (
cls.pretrained_init_configuration
and pretrained_model_name_or_path in cls.pretrained_init_configuration
):
init_configuration = cls.pretrained_init_configuration[
pretrained_model_name_or_path
]
else:
# Get the vocabulary from local files
logger.info(
"Model name '{}' not found in model shortcut name list ({}). "
"Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
pretrained_model_name_or_path,
", ".join(s3_models),
pretrained_model_name_or_path,
)
)
# Look for the tokenizer main vocabulary files
for file_id, file_name in cls.vocab_files_names.items():
if os.path.isdir(pretrained_model_name_or_path):
# If a directory is provided we look for the standard filenames
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
else:
# If a path to a file is provided we use it (will only work for non-BPE
# tokenizer using a single vocabulary file)
full_file_name = pretrained_model_name_or_path
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None
vocab_files[file_id] = full_file_name
# Look for the additional tokens files
additional_files_names = {
"added_tokens_file": ADDED_TOKENS_FILE,
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
}
# If a path to a file was provided, get the parent directory
saved_directory = pretrained_model_name_or_path
if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
saved_directory = os.path.dirname(saved_directory)
for file_id, file_name in additional_files_names.items():
full_file_name = os.path.join(saved_directory, file_name)
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None
vocab_files[file_id] = full_file_name
if all(full_file_name is None for full_file_name in vocab_files.values()):
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find tokenizer files"
"at this path or url.".format(
pretrained_model_name_or_path,
", ".join(s3_models),
pretrained_model_name_or_path,
)
)
return None
# Get files from url, cache, or disk depending on the case
try:
resolved_vocab_files = {}
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files[file_id] = cached_path(
file_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
)
except EnvironmentError as e:
if pretrained_model_name_or_path in s3_models:
logger.error("Couldn't reach server to download vocabulary.")
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
", ".join(s3_models),
pretrained_model_name_or_path,
str(vocab_files.keys()),
)
)
raise e
for file_id, file_path in vocab_files.items():
if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path))
else:
logger.info(
"loading file {} from cache at {}".format(
file_path, resolved_vocab_files[file_id]
)
)
# Prepare tokenizer initialization kwargs
# Did we save some inputs and kwargs to reload?
tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
if tokenizer_config_file is not None:
init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
saved_init_inputs = init_kwargs.pop("init_inputs", ())
if not init_inputs:
init_inputs = saved_init_inputs
else:
init_kwargs = init_configuration
# Update with newly provided kwargs
init_kwargs.update(kwargs)
# Merge resolved_vocab_files arguments in init_kwargs.
added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
for args_name, file_path in resolved_vocab_files.items():
if args_name not in init_kwargs:
init_kwargs[args_name] = file_path
if special_tokens_map_file is not None:
special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
for key, value in special_tokens_map.items():
if key not in init_kwargs:
init_kwargs[key] = value
# Instantiate tokenizer.
tokenizer = cls(*init_inputs, **init_kwargs)
# Save inputs and kwargs for saving and re-loading with ``save_pretrained``
tokenizer.init_inputs = init_inputs
tokenizer.init_kwargs = init_kwargs
# Add supplementary tokens.
special_tokens = tokenizer.all_special_tokens
if added_tokens_file is not None:
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
added_tok_encoder = json.load(added_tokens_handle)
# Sort added tokens by index
added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
for token, index in added_tok_encoder_sorted:
assert index == len(tokenizer), (
f"Non-consecutive added token '{token}' found. "
f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
)
tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
# Check all our special tokens are registered as "no split" token
# (we don't cut them) and are in the vocab
added_tokens = tokenizer.sanitize_special_tokens()
if added_tokens:
logger.warning(
"Special tokens have been added in the vocabulary,"
"make sure the associated word embedding are fine-tuned or trained."
)
return tokenizer
def save_pretrained(self, save_directory):
"""
Save the tokenizer vocabulary files together with:
- added tokens,
- special-tokens-to-class-attributes-mapping,
- tokenizer instantiation positional and keyword inputs (e.g. do_lower_case for Bert).
This won't save modifications, other than ``added tokens`` and the ``special token mapping``,
that you may have applied to the tokenizer after instantiation (e.g. modifying
tokenizer.do_lower_case after creation).
This method makes sure the full tokenizer can then be re-loaded using the
:func:`~PreTrainedTokenizer.from_pretrained` class method.
"""
if not PathManager.isdir(save_directory):
logger.error("Saving directory ({}) should be a directory".format(save_directory))
return
PathManager.mkdirs(save_directory)
special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
tokenizer_config = copy.deepcopy(self.init_kwargs)
if len(self.init_inputs) > 0:
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
for file_id in self.vocab_files_names.keys():
tokenizer_config.pop(file_id, None)
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
out_str = json.dumps(added_vocab, ensure_ascii=False)
f.write(out_str)
vocab_files = self.save_vocabulary(save_directory)
return vocab_files + (special_tokens_map_file, added_tokens_file)
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
and special token mappings.
Please use :func:`~PreTrainedTokenizer.save_pretrained` to save the
full Tokenizer state if you want to reload it using the
:func:`~PreTrainedTokenizer.from_pretrained` class method.
"""
raise NotImplementedError
@property
def vocab_size(self) -> int:
"""Size of the base vocabulary (without the added tokens)."""
raise NotImplementedError
def padded_vocab_size(self, multiple=1) -> int:
"""Padded the vocabulary with dummy tokens and return the new size."""
vocab_size = len(self)
while vocab_size % multiple != 0:
vocab_size += 1
return vocab_size
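# Worked example (added comment): with a BERT-base vocabulary of 30522 tokens and no
# added tokens, padding to a multiple of 128 gives
#   tokenizer.padded_vocab_size(multiple=128)  # -> 30592
# a common choice when the embedding table must split evenly in tensor-parallel setups.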
def __len__(self):
"""Size of the full vocabulary with the added tokens."""
return self.vocab_size + len(self.added_tokens_encoder)
def get_vocab(self) -> Dict[str, int]:
"""
Returns the vocabulary as a dictionary of token to index.
:obj:`tokenizer.get_vocab()[token]` is equivalent to
:obj:`tokenizer.convert_tokens_to_ids(token)`
when :obj:`token` is in the vocab.
Returns:
:obj:`Dict[str, int]`: The vocabulary.
"""
raise NotImplementedError
def get_added_vocab(self) -> Dict[str, int]:
"""
Returns the added tokens in the vocabulary as a dictionary of token to index.
Returns:
:obj:`Dict[str, int]`: The added tokens.
"""
return self.added_tokens_encoder
def add_tokens(self, new_tokens: Union[str, List[str]], special_tokens: bool = False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from the length of
the current vocabulary.
.. Note::
When adding new tokens to the vocabulary, you should make sure to also resize
the token embedding matrix of the model so that its embedding matrix matches
the tokenizer.
In order to do that, please use the
:meth:`~PreTrainedModel.resize_token_embeddings` method.
Args:
new_tokens (:obj:`str`, or a list of `str`):
Tokens are only added if they are not already in the vocabulary.
special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Can be used to specify if the token is a special token. This mostly changes
the normalization behavior
(special tokens like CLS or [MASK] are usually not lower-cased for instance).
Returns:
:obj:`int`: Number of tokens added to the vocabulary.
Examples:
.. code-block:: python
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
print('We have added', num_added_toks, 'tokens')
# Notice: resize_token_embeddings expects to receive the full size of the new
# vocabulary, i.e., the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
"""
if not new_tokens:
return 0
if not isinstance(new_tokens, (list, tuple)):
new_tokens = [new_tokens]
tokens_to_add = []
for token in new_tokens:
if not isinstance(token, str):
raise TypeError(f"Token {token} is not a string but a {type(token)}.")
if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case:
token = token.lower()
if (
token != self.unk_token
and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
and token not in tokens_to_add
):
tokens_to_add.append(token)
if self.verbose:
logger.info(f"Adding {token} to the vocabulary")
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)
if special_tokens:
self.unique_no_split_tokens = sorted(
set(self.unique_no_split_tokens).union(set(new_tokens))
)
else:
self.unique_no_split_tokens = sorted(
set(self.unique_no_split_tokens).union(set(tokens_to_add))
)
return len(tokens_to_add)
def sanitize_special_tokens(self) -> int:
"""
Make sure that all the special tokens attributes of the tokenizer
(:obj:`tokenizer.mask_token`, :obj:`tokenizer.cls_token`, etc.)
are in the vocabulary.
Add the missing ones to the vocabulary if needed.
Return:
:obj:`int`: The number of tokens added to the vocabulary during the operation.
"""
return self.add_tokens(self.all_special_tokens, special_tokens=True)
def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
"""
Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to
class attributes. If special tokens are NOT in the vocabulary, they are added to it
(indexed starting from the last index of the current vocabulary).
.. Note::
When adding new tokens to the vocabulary, you should make sure to also resize the
token embedding matrix of the model so that its embedding matrix matches the tokenizer.
In order to do that, please use the
:meth:`~PreTrainedModel.resize_token_embeddings` method.
Using :obj:`add_special_tokens` will ensure your special tokens can be used in several ways:
- Special tokens are carefully handled by the tokenizer (they are never split).
- You can easily refer to special tokens using tokenizer class attributes like
:obj:`tokenizer.cls_token`. This makes it easy to develop model-agnostic training and
fine-tuning scripts.
When possible, special tokens are already registered for provided pretrained models
(for instance :class:`~BertTokenizer` :obj:`cls_token` is already registered
to be :obj:`'[CLS]'` and XLM's one is also registered to be :obj:`'</s>'`).
Args:
special_tokens_dict (dictionary `str` to `str`):
Keys should be in the list of predefined special attributes: [``bos_token``,
``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``,
``cls_token``, ``mask_token``,
``additional_special_tokens``].
Tokens are only added if they are not already in the vocabulary (tested by
checking if the tokenizer assigns the index of the ``unk_token`` to them).
Returns:
:obj:`int`: Number of tokens added to the vocabulary.
Examples:
.. code-block:: python
# Let's see how to add a new classification token to GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
special_tokens_dict = {'cls_token': '<CLS>'}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'tokens')
# Notice: resize_token_embeddings expects to receive the full size of the new vocabulary,
# i.e., the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
assert tokenizer.cls_token == '<CLS>'
"""
if not special_tokens_dict:
return 0
added_tokens = 0
for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"
if self.verbose:
logger.info(f"Assigning {value} to the {key} key of the tokenizer")
setattr(self, key, value)
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(
isinstance(t, str) for t in value
), f"Tokens {value} for key {key} should all be a string"
added_tokens += self.add_tokens(value, special_tokens=True)
else:
assert isinstance(value, str), f"Token {value} for key {key} should be a string"
added_tokens += self.add_tokens([value], special_tokens=True)
return added_tokens
def tokenize(self, text: str, **kwargs) -> List[str]:
"""
Converts a string into a sequence of tokens, using the tokenizer.
Splits into words for word-based vocabularies or sub-words for sub-word-based vocabularies
(BPE/SentencePieces/WordPieces). Takes care of added tokens.
Args:
text (:obj:`str`):
The sequence to be encoded.
**kwargs (additional keyword arguments):
Passed along to the model-specific ``prepare_for_tokenization``
preprocessing method.
Returns:
:obj:`List[str]`: The list of tokens.
"""
def split_on_token(tok, text):
result = []
split_text = text.split(tok)
for i, sub_text in enumerate(split_text):
sub_text = sub_text.strip()
if i == 0 and not sub_text:
result += [tok]
elif i == len(split_text) - 1:
if sub_text:
result += [sub_text]
else:
pass
else:
if sub_text:
result += [sub_text]
result += [tok]
return result
def split_on_tokens(tok_list, text):
if not text:
return []
if not tok_list:
return self._tokenize(text, **kwargs)
tokenized_text = []
text_list = [text]
for tok in tok_list:
tokenized_text = []
for sub_text in text_list:
if sub_text not in self.unique_no_split_tokens:
tokenized_text += split_on_token(tok, sub_text)
else:
tokenized_text += [sub_text]
text_list = tokenized_text
return list(
itertools.chain.from_iterable(
(
self._tokenize(token)
if token not in self.unique_no_split_tokens
else [token]
for token in tokenized_text
)
)
)
no_split_token = self.unique_no_split_tokens
tokenized_text = split_on_tokens(no_split_token, text)
return tokenized_text
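# Illustrative behaviour (added comment): tokens registered in `unique_no_split_tokens`
# (e.g. special or added tokens) are passed through intact, everything else goes through
# `_tokenize`. With a BERT-style tokenizer that knows "[SEP]", roughly:
#
#   tokenizer.tokenize("hello world [SEP] bye")
#   # -> ["hello", "world", "[SEP]", "bye"]   (exact sub-word pieces depend on the vocabulary)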
def _tokenize(self, text, **kwargs):
"""
Converts a string into a sequence of tokens (strings), using the tokenizer. Splits into
words for word-based vocabularies or sub-words for sub-word-based vocabularies
(BPE/SentencePieces/WordPieces).
Does NOT take care of added tokens.
"""
raise NotImplementedError
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
"""Converts a token string (or a sequence of tokens) in a single integer id
(or a sequence of ids), using the vocabulary.
"""
if tokens is None:
return None
if isinstance(tokens, str):
return self._convert_token_to_id_with_added_voc(tokens)
if len(tokens) > 0 and isinstance(tokens[0], list):
ids = []
for ts in tokens:
ids_x = []
for token in ts:
ids_x.append(self._convert_token_to_id_with_added_voc(token))
ids.append(ids_x)
return ids
ids = []
for token in tokens:
ids.append(self._convert_token_to_id_with_added_voc(token))
return ids
def convert_to_tensors(self, token_ids, return_tensors=None, is_global=False, **kwargs):
if return_tensors is None:
return_token_ids = token_ids
elif return_tensors == "of":
if not is_global:
return_token_ids = flow.tensor(token_ids, dtype=flow.long)
elif is_global:
sbp = kwargs.get("sbp", dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]))
placement = kwargs.get(
"placement", flow.placement("cuda", list(range(dist.get_world_size())))
)
return_token_ids = flow.tensor(
token_ids, sbp=sbp, placement=placement, dtype=flow.long
)
elif return_tensors == "np":
return_token_ids = np.array(token_ids, dtype=np.int64)
return return_token_ids
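# Examples of the three return modes above (added comment; the ids are placeholders):
#
#   tokenizer.convert_to_tensors([101, 2023, 102])                       # plain list, unchanged
#   tokenizer.convert_to_tensors([101, 2023, 102], return_tensors="np")  # numpy int64 array
#   tokenizer.convert_to_tensors([101, 2023, 102], return_tensors="of")  # local flow.long tensor
#   # With return_tensors="of" and is_global=True, a global tensor is built whose
#   # sbp/placement default to broadcast across all ranks (see dist.get_nd_sbp above).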
def _convert_token_to_id_with_added_voc(self, token):
if token is None:
return None
if token in self.added_tokens_encoder:
return self.added_tokens_encoder[token]
return self._convert_token_to_id(token)
def _convert_token_to_id(self, token):
raise NotImplementedError
def encode(self, text, return_tensors=None, is_global=False, **kwargs):
if isinstance(text, str):
tokens = self.tokenize(text)
token_ids = self.convert_tokens_to_ids(tokens)
token_ids = self.build_inputs_with_special_tokens(token_ids)
token_ids = self.convert_to_tensors(
token_ids, return_tensors=return_tensors, is_global=is_global, **kwargs
)
return token_ids
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
tokens = [self.tokenize(t) for t in text]
token_ids_list = self.convert_tokens_to_ids(tokens)
token_ids_list = [
self.build_inputs_with_special_tokens(token_ids) for token_ids in token_ids_list
]
token_ids_list = self.convert_to_tensors(
token_ids_list, return_tensors=return_tensors, is_global=is_global, **kwargs
)
return token_ids_list
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
return text
else:
raise ValueError(
"Input is not valid. Should be a string, a list/tuple of strings or "
"a list/tuple of integers."
)
def convert_ids_to_tokens(
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
) -> Union[str, List[str]]:
"""
Converts a single index or a sequence of indices in a token or a sequence of tokens,
using the vocabulary and added tokens.
Args:
ids (:obj:`int` or :obj:`List[int]`):
The token id (or token ids) to convert to tokens.
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to remove special tokens in the decoding.
Returns:
:obj:`str` or :obj:`List[str]`: The decoded token(s).
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
return self.added_tokens_decoder[ids]
else:
return self._convert_id_to_token(ids)
tokens = []
for index in ids:
if skip_special_tokens and index in self.all_special_ids:
continue
if index in self.added_tokens_decoder:
tokens.append(self.added_tokens_decoder[index])
else:
tokens.append(self._convert_id_to_token(index))
return tokens
def _convert_id_to_token(self, index: int) -> str:
raise NotImplementedError
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""
Converts a sequence of tokens to a single string. The simplest way to do it is
``" ".join(tokens)`` but we often want to remove sub-word tokenization artifacts
at the same time.
Args:
tokens (:obj:`List[str]`): The tokens to join into a string.
Returns:
:obj:`str`: The joined tokens.
"""
return " ".join(tokens)
def decode(
self,
token_ids,
skip_special_tokens=False,
clean_up_tokenization_spaces=True,
spaces_between_special_tokens: bool = True,
):
"""
Converts a sequence of ids (integers) into a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
Args:
token_ids: list of tokenized input ids. Can be obtained using the `encode` or
`encode_plus` methods.
skip_special_tokens: if set to True, will remove special tokens in the decoding.
clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
"""
# Convert inputs to python lists
if isinstance(token_ids, flow.Tensor):
token_ids = token_ids.tolist()
filtered_tokens = self.convert_ids_to_tokens(
token_ids, skip_special_tokens=skip_special_tokens
)
# To avoid mixing byte-level and unicode for byte-level BPE,
# we need to build the string separately for added tokens and byte-level tokens
# cf. https://github.com/huggingface/transformers/issues/1133
sub_texts = []
current_sub_text = []
for token in filtered_tokens:
if skip_special_tokens and token in self.all_special_ids:
continue
if token in self.added_tokens_encoder:
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
if spaces_between_special_tokens:
text = " ".join(sub_texts)
else:
text = "".join(sub_texts)
if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
return clean_text
else:
return text
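# Hedged round-trip sketch (added comment): encode then decode a string; special-token
# handling depends on the concrete subclass.
#
#   ids = tokenizer.encode("hello world")
#   text = tokenizer.decode(ids, skip_special_tokens=True)
#   # `text` is the detokenized string with tokenization spaces cleaned up.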
@property
def bos_token(self) -> str:
"""
:obj:`str`: Beginning of sentence token. Log an error if used while not having been set.
"""
if self._bos_token is None and self.verbose:
logger.error("Using bos_token, but it is not set yet.")
return None
return str(self._bos_token)
@property
def eos_token(self) -> str:
"""
:obj:`str`: End of sentence token. Log an error if used while not having been set.
"""
if self._eos_token is None and self.verbose:
logger.error("Using eos_token, but it is not set yet.")
return None
return str(self._eos_token)
@property
def unk_token(self) -> str:
"""
:obj:`str`: Unknown token. Log an error if used while not having been set.
"""
if self._unk_token is None and self.verbose:
logger.error("Using unk_token, but it is not set yet.")
return None
return str(self._unk_token)
@property
def sep_token(self) -> str:
"""
:obj:`str`: Separation token, to separate context and query in an input sequence.
Log an error if used while not having been set.
"""
if self._sep_token is None and self.verbose:
logger.error("Using sep_token, but it is not set yet.")
return None
return str(self._sep_token)
@property
def pad_token(self) -> str:
"""
:obj:`str`: Padding token. Log an error if used while not having been set.
"""
if self._pad_token is None and self.verbose:
logger.error("Using pad_token, but it is not set yet.")
return None
return str(self._pad_token)
@property
def cls_token(self) -> str:
"""
:obj:`str`: Classification token, to extract a summary of an input sequence leveraging
self-attention along the full depth of the model.
Log an error if used while not having been set.
"""
if self._cls_token is None and self.verbose:
logger.error("Using cls_token, but it is not set yet.")
return None
return str(self._cls_token)
@property
def mask_token(self) -> str:
"""
:obj:`str`: Mask token, to use when training a model with masked-language modeling.
Log an error if used while not having been set.
"""
if self._mask_token is None and self.verbose:
logger.error("Using mask_token, but it is not set yet.")
return None
return str(self._mask_token)
@property
def eod_token(self) -> str:
"""
:obj:`str`: End of document token. Log an error if used while not having been set.
"""
if self._eod_token is None and self.verbose:
logger.error("Using eod_token, but it is not set yet.")
return None
return str(self._eod_token)
@property
def start_token(self) -> str:
"""
:obj:`str`: Start token of sentence. Common name for bos_token and cls_token.
"""
if self._bos_token is not None and self._cls_token is not None:
if self._bos_token == self._cls_token:
return str(self._bos_token)
else:
logger.error("Conflict between bos_token and cls_token.")
return None
elif self._bos_token is None and self._cls_token is not None:
return str(self._cls_token)
elif self._bos_token is not None and self._cls_token is None:
return str(self._bos_token)
else:
logger.error("Using start_token, but it is not set yet.")
return None
@property
def end_token(self) -> str:
"""
:obj:`str`: End token of sentence. Common name for eos_token and sep_token.
Note: eod_token is not considered, because it is often the same as eos_token.
"""
if self._eos_token is not None and self._sep_token is not None:
if self._eos_token == self._sep_token:
return str(self._eos_token)
else:
logger.error("Conflict between eos_token and _sep_token.")
return None
elif self._eos_token is None and self._sep_token is not None:
return str(self._sep_token)
elif self._eos_token is not None and self._sep_token is None:
return str(self._eos_token)
else:
logger.error("Using end_token, but it is not set yet.")
return None
@property
def additional_special_tokens(self) -> List[str]:
"""
:obj:`List[str]`: All the additional special tokens you may want to use.
Log an error if used while not having been set.
"""
if self._additional_special_tokens is None and self.verbose:
logger.error("Using additional_special_tokens, but it is not set yet.")
return None
return [str(tok) for tok in self._additional_special_tokens]
@bos_token.setter
def bos_token(self, value):
self._bos_token = value
@eos_token.setter
def eos_token(self, value):
self._eos_token = value
@unk_token.setter
def unk_token(self, value):
self._unk_token = value
@sep_token.setter
def sep_token(self, value):
self._sep_token = value
@pad_token.setter
def pad_token(self, value):
self._pad_token = value
@cls_token.setter
def cls_token(self, value):
self._cls_token = value
@mask_token.setter
def mask_token(self, value):
self._mask_token = value
@eod_token.setter
def eod_token(self, value):
self._eod_token = value
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
@property
def bos_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the beginning of sentence token in the vocabulary.
Returns :obj:`None` if the token has not been set.
"""
if self._bos_token is None:
return None
return self.convert_tokens_to_ids(self.bos_token)
@property
def eos_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the end of sentence token in the vocabulary.
Returns :obj:`None` if the token has not been set.
"""
if self._eos_token is None:
return None
return self.convert_tokens_to_ids(self.eos_token)
@property
def unk_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the unknown token in the vocabulary.
Returns :obj:`None` if the token has not been set.
"""
if self._unk_token is None:
return None
return self.convert_tokens_to_ids(self.unk_token)
@property
def sep_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the separation token in the vocabulary,
to separate context and query in an input sequence.
Returns :obj:`None` if the token has not been set.
"""
if self._sep_token is None:
return None
return self.convert_tokens_to_ids(self.sep_token)
@property
def pad_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the padding token in the vocabulary.
Returns :obj:`None` if the token has not been set.
"""
if self._pad_token is None:
return None
return self.convert_tokens_to_ids(self.pad_token)
@property
def cls_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the classification token in the vocabulary,
to extract a summary of an input sequence leveraging self-attention
along the full depth of the model.
Returns :obj:`None` if the token has not been set.
"""
if self._cls_token is None:
return None
return self.convert_tokens_to_ids(self.cls_token)
@property
def mask_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the mask token in the vocabulary, used when training a
model with masked-language modeling. Returns :obj:`None` if the token has not been set.
"""
if self._mask_token is None:
return None
return self.convert_tokens_to_ids(self.mask_token)
@property
def eod_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the end of document token in the vocabulary.
Returns :obj:`None` if the token has not been set.
"""
if self._eod_token is None:
return None
return self.convert_tokens_to_ids(self.eod_token)
@property
def start_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the start token in the vocabulary.
Returns :obj:`None` if the token has not been set.
"""
start_token = self.start_token
if start_token is None:
return None
else:
return self.convert_tokens_to_ids(start_token)
@property
def end_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the end token in the vocabulary.
Returns :obj:`None` if the token has not been set.
"""
end_token = self.end_token
if end_token is None:
return None
else:
return self.convert_tokens_to_ids(end_token)
@property
def additional_special_tokens_ids(self) -> List[int]:
"""
:obj:`List[int]`: Ids of all the additional special tokens in the vocabulary.
Log an error if used while not having been set.
"""
return self.convert_tokens_to_ids(self.additional_special_tokens)
@property
def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:
"""
A dictionary mapping special token class attributes
(:obj:`cls_token`, :obj:`unk_token`, etc.) to their values
(:obj:`'<unk>'`, :obj:`'<cls>'`, etc.).
"""
set_attr = {}
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
attr_value = getattr(self, "_" + attr)
if attr_value:
set_attr[attr] = attr_value
return set_attr
@property
def all_special_tokens(self) -> List[str]:
"""
:obj:`List[str]`: All the special tokens
(:obj:`'<unk>'`, :obj:`'<cls>'`, etc.) mapped to class attributes.
"""
all_toks = []
set_attr = self.special_tokens_map
for attr_value in set_attr.values():
all_toks = all_toks + (
list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value]
)
all_toks = list(set(all_toks))
return all_toks
@property
def all_special_ids(self) -> List[int]:
"""
:obj:`List[int]`: List the ids of the special tokens
(:obj:`'<unk>'`, :obj:`'<cls>'`, etc.) mapped to class attributes.
"""
all_toks = self.all_special_tokens
all_ids = list(self.convert_tokens_to_ids(all_toks))
return all_ids
@staticmethod
def clean_up_tokenization(out_string):
"""Clean up a list of simple English tokenization artifacts like spaces before
punctuations and abbreviated forms.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" do not", " don't")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
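# A minimal, hedged sketch of what ``clean_up_tokenization`` does to decoded text.
# The helper takes the enclosing tokenizer class (presumably the base tokenizer
# class defined above) as an argument so it stays agnostic to the class name.
# It is illustrative only and is not called anywhere in the library.
def _demo_clean_up_tokenization(tokenizer_cls):
    decoded = "do n't hesitate , it 's fine ."
    cleaned = tokenizer_cls.clean_up_tokenization(decoded)
    # Detached punctuation and contractions are re-attached:
    assert cleaned == "don't hesitate, it's fine."
    return cleaned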
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for bert (wordpieces)."""
import collections
import logging
import os
import re
import unicodedata
from io import open
from typing import List, Optional
from .tokenization_base import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"bert-base-uncased": 512,
"bert-large-uncased": 512,
"bert-base-cased": 512,
"bert-large-cased": 512,
"bert-base-chinese": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"bert-base-uncased": {"do_lower_case": True},
"bert-large-uncased": {"do_lower_case": True},
"bert-base-cased": {"do_lower_case": False},
"bert-large-cased": {"do_lower_case": False},
"bert-base-chinese": {"do_lower_case": False},
}
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n")
vocab[token] = index
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
def _is_chinese_substr(substr):
return re.findall("##[\u4E00-\u9FA5]", substr)
class BertTokenizer(PreTrainedTokenizer):
"""
Construct a BERT tokenizer. Based on WordPiece.
Args:
vocab_file (:obj:`str`):
Path to a one-wordpiece-per-line vocabulary file.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to lower case the input.
Only has an effect when do_basic_tokenize=True.
do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to do basic tokenization before wordpiece.
never_split (:obj:`Iterable`, `optional`):
List of tokens which will never be split during tokenization.
Only has an effect when do_basic_tokenize=True.
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to tokenize Chinese characters.
This should likely be deactivated for Japanese,
see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328.
do_chinese_wwm (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to do whole word masking for Chinese.
Chinese sentence will be segmented by a third-party tool first.
Each substring will be prefixed with '##' and its index will be calculated by
id(##A) = id(A) + vocab_size.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
do_chinese_wwm=False,
add_bos_token=False,
**kwargs,
):
super(BertTokenizer, self).__init__(
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the "
"vocabulary from a Google pretrained model use "
"`tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
vocab_file
)
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()]
)
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
if do_chinese_wwm:
self.basic_tokenizer = BasicTokenizerWithChineseWWM(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
)
else:
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
self.add_bos_token = add_bos_token
@property
def vocab_size(self):
return len(self.vocab)
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
def _tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab.
For Chinese substr, id = vocab_size + id(substr.remove(##)).
"""
index = self.vocab.get(token, self.vocab.get(self.unk_token))
return index
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab.
For Chinese substr, id = vocab_size + id(substr.remove(##)).
"""
token = self.ids_to_tokens.get(index, self.unk_token)
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) to a single string."""
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""Add special tokens to a sequence or a pair of sequence.
BERT format sentence input:
- single sequence: [CLS] tokens_a [SEP]
- pair of sequences: [CLS] tokens_a [SEP] tokens_b [SEP]
Args:
token_ids_0 (List[int]): The token ids of sentence 0.
token_ids_1 (List[int], optional): The token ids of sentence 1. Defaults to None.
Returns:
:obj:`List[int]`: The sequence after adding special tokens.
"""
if self.add_bos_token:
cls = [self.cls_token_id]
sep = [self.sep_token_id]
else:
cls = []
sep = []
if token_ids_1 is None:
return cls + token_ids_0 + sep
return cls + token_ids_0 + sep + token_ids_1 + sep
def save_vocabulary(self, save_directory, filename_prefix=None):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "")
+ VOCAB_FILES_NAMES["vocab_file"],
)
else:
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
"Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file)
)
index = token_index
writer.write(token + "\n")
index += 1
return (vocab_file,)
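# A hedged usage sketch for ``BertTokenizer``. The vocabulary path below is
# hypothetical, the exact wordpieces depend on the vocab file, and the sketch
# assumes the base class exposes ``tokenize``/``convert_tokens_to_ids`` as
# referenced in the docstrings above. Illustrative only; not called anywhere.
def _demo_bert_tokenizer(vocab_path="path/to/vocab.txt"):  # hypothetical path
    tokenizer = BertTokenizer(vocab_path, do_lower_case=True, add_bos_token=True)
    tokens = tokenizer.tokenize("The unaffable researcher smiled.")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # With add_bos_token=True this wraps the ids as [CLS] ids [SEP].
    return tokenizer.build_inputs_with_special_tokens(ids)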
class BasicTokenizer(object):
"""
Constructs a BasicTokenizer that will run basic
tokenization (punctuation splitting, lower casing, etc.).
"""
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
"""Constructs a BasicTokenizer.
Args:
**do_lower_case**: Whether to lower case the input.
**never_split**: (`optional`) list of str
Kept for backward compatibility purposes.
Now implemented directly at the base class level
(see :func:`PreTrainedTokenizer.tokenize`)
List of tokens not to split.
**tokenize_chinese_chars**: (`optional`) boolean (default True)
Whether to tokenize Chinese characters.
This should likely be deactivated for Japanese:
see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
"""
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case
self.never_split = set(never_split)
self.tokenize_chinese_chars = tokenize_chinese_chars
def tokenize(self, text, never_split=None):
"""
Basic Tokenization of a piece of text.
Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
Args:
**never_split**: (`optional`) list of str
Kept for backward compatibility purposes.
Now implemented directly at the base class level
(see :func:`PreTrainedTokenizer.tokenize`)
List of tokens not to split.
"""
# union() returns a new set by concatenating the two sets.
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
if self.tokenize_chinese_chars:
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case and token not in never_split:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token, never_split))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if never_split is not None and text in never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
like all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xFFFD or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class BasicTokenizerWithChineseWWM(BasicTokenizer):
"""Pre-segmentation for Chinese sentences, which will be used in whole word mask."""
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
super(BasicTokenizerWithChineseWWM, self).__init__(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
)
try:
import jieba
self.pre_tokenizer = lambda x: jieba.lcut(x, HMM=False)
except ImportError:
raise ImportError("Chinese whole word masking requires jieba")
def _tokenize_chinese_chars(self, text):
"""For Chinese pieces, uses jieba to segment the words and
adds whitespace around CJK characters."""
output = []
piece = ""
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
piece += char
else:
chinese_words = self.pre_tokenizer(piece)
for word in chinese_words:
output.append(" ")
output.append(word)
output.append(" ")
output.append(char)
piece = ""
chinese_words = self.pre_tokenizer(piece)
for word in chinese_words:
output.append(" ")
output.append(word)
output.append(" ")
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
input = "有没有"
output = ["有", "##没", "##有"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr.startswith("##"):
if _is_chinese_substr(substr):
if substr[2:] in self.vocab: # for Chinese substr
cur_substr = substr
break
else:
if substr in self.vocab: # for English substr
cur_substr = substr
break
else:
if (
substr in self.vocab
): # non-substr, maybe character or whole Chinese word
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
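# A hedged sketch of the greedy longest-match-first WordPiece algorithm using
# the toy vocabulary from the docstring above. Illustrative only; not called
# anywhere in the library.
def _demo_wordpiece_tokenizer():
    toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
    wordpiece = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")
    # "unaffable" is split greedily into the longest matching pieces.
    assert wordpiece.tokenize("unaffable") == ["un", "##aff", "##able"]
    # A token with no matching pieces falls back to the unknown token.
    assert wordpiece.tokenize("xyz") == ["[UNK]"]
    return wordpiece.tokenize("unaffable")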
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT (BPE)."""
import json
import logging
import os
from functools import lru_cache
from io import open
from typing import List, Optional
import regex as re
from .tokenization_base import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {"gpt2": "https://huggingface.co/gpt2/resolve/main/vocab.json"},
"merges_file": {"gpt2": "https://huggingface.co/gpt2/resolve/main/merges.txt"},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"gpt2": 1024,
}
@lru_cache()
def bytes_to_unicode():
"""
Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping
to whitespace/control characters the bpe code barfs on.
The reversible bpe codes work on unicode strings. This means you need a large # of unicode
characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token
dataset you end up needing around 5K for decent coverage. This is a significant percentage
of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8
bytes and unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""
Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
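# A hedged illustration of the two helpers above. ``bytes_to_unicode`` maps every
# byte to a printable unicode character (the space byte becomes "Ġ"), and
# ``get_pairs`` enumerates adjacent symbol pairs for the BPE merge loop.
# Illustrative only; not called anywhere in the library.
def _demo_bpe_helpers():
    byte_encoder = bytes_to_unicode()
    assert byte_encoder[ord("A")] == "A"       # printable bytes map to themselves
    assert byte_encoder[ord(" ")] == "\u0120"  # the space byte maps to "Ġ"
    assert get_pairs(("h", "e", "l", "l", "o")) == {
        ("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")
    }
    return byte_encoder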
class GPT2Tokenizer(PreTrainedTokenizer):
"""
Construct a GPT-2 tokenizer. Based on byte-level Byte-Pair-Encoding.
Args:
vocab_file (:obj:`str`):
Path to the vocabulary file.
merges_file (:obj:`str`):
Path to the merges file.
errors (:obj:`str`, `optional`, defaults to :obj:`"replace"`):
Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode
<https://docs.python.org/3/library/stdtypes.html#bytes.decode>`__ for more information.
unk_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):
The unknown token. A token that is not in the vocabulary cannot be
converted to an ID and is set to be this token instead.
bos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):
The beginning of sequence token.
eos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):
The end of sequence token.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
add_bos_token=False,
**kwargs,
):
super(GPT2Tokenizer, self).__init__(
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
)
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding="utf-8").read().split("\n")[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should have added re.IGNORECASE so that BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
self.add_bos_token = add_bos_token
@property
def vocab_size(self):
return len(self.encoder)
def get_vocab(self):
return dict(self.encoder, **self.added_tokens_encoder)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError: # `first` does not occur in the remainder of `word`
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def _tokenize(self, text):
"""Tokenize a string."""
bpe_tokens = []
for token in re.findall(self.pat, text):
# Maps all our bytes to unicode strings, avoiding control tokens
# of the BPE (spaces in our case)
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.decoder.get(index)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) to a single string."""
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""Add special tokens to a sequence or a pair of sequence.
GPT2 format sentence input:
- single sequence: <|endoftext|> tokens_a
- pair of sequences: <|endoftext|> tokens_a <|endoftext|> tokens_b
Args:
token_ids_0 (List[int]): The token ids of sentence 0.
token_ids_1 (List[int], optional): The token ids of sentence 1. Defaults to None.
Returns:
:obj:`List[int]`: The sequence after adding special tokens.
"""
if self.add_bos_token:
bos = [self.bos_token_id]
else:
bos = []
if token_ids_1 is None:
return bos + token_ids_0
return bos + token_ids_0 + bos + token_ids_1
def save_vocabulary(self, save_directory, filename_prefix=None):
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
)
merge_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"],
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
)
index = token_index
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return (vocab_file, merge_file)
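# A hedged usage sketch for ``GPT2Tokenizer``. The vocab/merges paths are
# hypothetical, and the sketch assumes the base class exposes
# ``tokenize``/``convert_tokens_to_ids`` as referenced in the docstrings above.
# Illustrative only; not called anywhere in the library.
def _demo_gpt2_tokenizer(vocab_path="path/to/vocab.json",    # hypothetical paths
                         merges_path="path/to/merges.txt"):
    tokenizer = GPT2Tokenizer(vocab_path, merges_path, add_bos_token=True)
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world"))
    # With add_bos_token=True, <|endoftext|> is prepended to the sequence.
    return tokenizer.build_inputs_with_special_tokens(ids)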
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RoBERTa."""
import json
import logging
import os
from functools import lru_cache
from typing import List, Optional, Tuple
import regex as re
from .tokenization_base import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"roberta-base": "https://huggingface.co/roberta-base/resolve/main/vocab.json",
"roberta-large": "https://huggingface.co/roberta-large/resolve/main/vocab.json",
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/vocab.json",
},
"merges_file": {
"roberta-base": "https://huggingface.co/roberta-base/resolve/main/merges.txt",
"roberta-large": "https://huggingface.co/roberta-large/resolve/main/merges.txt",
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"roberta-base": 512,
"roberta-large": 512,
"roberta-large-mnli": 512,
}
@lru_cache()
def bytes_to_unicode():
"""
Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to
whitespace/control characters the bpe code barfs on.
The reversible bpe codes work on unicode strings. This means you need a large # of unicode
characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token
dataset you end up needing around 5K for decent coverage. This is a significant percentage of
your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""
Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class RobertaTokenizer(PreTrainedTokenizer):
"""Constructs a RoBERTa tokenizer, derived from the GPT-2 tokenizer,
using byte-level Byte-Pair-Encoding.
Args:
vocab_file (:obj:`str`):
Path to the vocabulary file.
merges_file (:obj:`str`):
Path to the merges file.
errors (:obj:`str`, `optional`, defaults to :obj:`"replace"`):
Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode
<https://docs.python.org/3/library/stdtypes.html#bytes.decode>`__ for more information.
bos_token (:obj:`str`, `optional`, defaults to `<s>`):
The beginning of sequence token.
eos_token (:obj:`str`, `optional`, defaults to `</s>`):
The end of sequence token.
cls_token (:obj:`str`, `optional`, defaults to `<s>`):
The first token of the sequence when built with special tokens.
unk_token (:obj:`str`, `optional`, defaults to `<unk>`):
The unknown token. A token that is not in the vocabulary cannot be
converted to an ID and is set to be this token instead.
pad_token (:obj:`str`, `optional`, defaults to `<pad>`): A special token
used to make arrays of tokens the same size for batching purpose.
Will then be ignored by attention mechanisms or loss computation.
mask_token (:obj:`str`, `optional`, defaults to `<mask>`): A special token
representing a masked token (used by masked-language modeling pretraining
objectives, like BERT).
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
add_bos_token=False,
**kwargs,
):
super(RobertaTokenizer, self).__init__(
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
unk_token=unk_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs,
)
with open(vocab_file, encoding="utf-8") as file:
self.encoder = json.load(file)
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with open(merges_file, encoding="utf-8") as file:
bpe_merges = file.read().split("\n")[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
self.add_bos_token = add_bos_token
@property
def vocab_size(self):
return len(self.encoder)
def get_vocab(self):
return dict(self.encoder, **self.added_tokens_encoder)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def _tokenize(self, text):
"""Tokenize a string."""
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.decoder.get(index)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""Add special tokens to a sequence or a pair of sequence.
RoBERTa format sentence input:
- single sequence: [CLS] tokens_a [SEP]
- pair of sequences: [CLS] tokens_a [SEP] tokens_b [SEP]
Args:
token_ids_0 (List[int]): The token ids of sentence 0.
token_ids_1 (List[int], optional): The token ids of sentence 1. Defaults to None.
Returns:
:obj:`List[int]`: The sequence after adding special tokens.
"""
if self.add_bos_token:
cls = [self.cls_token_id]
sep = [self.sep_token_id]
else:
cls = []
sep = []
if token_ids_1 is None:
return cls + token_ids_0 + sep
return cls + token_ids_0 + sep + token_ids_1 + sep
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
)
merge_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"],
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
)
index = token_index
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return vocab_file, merge_file
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair
classification task. RoBERTa does not make use of token type ids, therefore
a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
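# A hedged usage sketch for ``RobertaTokenizer``. The paths are hypothetical and
# the sketch assumes the base class exposes ``tokenize``/``convert_tokens_to_ids``.
# ``create_token_type_ids_from_sequences`` always returns zeros because RoBERTa
# does not use token type ids. Illustrative only; not called anywhere.
def _demo_roberta_tokenizer(vocab_path="path/to/vocab.json",    # hypothetical paths
                            merges_path="path/to/merges.txt"):
    tokenizer = RobertaTokenizer(vocab_path, merges_path, add_bos_token=True)
    ids_a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("A question?"))
    ids_b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("An answer."))
    inputs = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)
    token_type_ids = tokenizer.create_token_type_ids_from_sequences(ids_a, ids_b)
    assert set(token_type_ids) == {0}
    return inputs, token_type_ids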
# coding=utf-8
# Copyright 2018 T5 Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for Google T5 (sentence piece)."""
import logging
import os
import warnings
from shutil import copyfile
from typing import List, Optional
import regex as re
import sentencepiece as spm
from .tokenization_base import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {"t5-base": "https://huggingface.co/t5-base/resolve/main/spiece.model"}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"t5-base": 512,
}
class T5Tokenizer(PreTrainedTokenizer):
"""
Construct a T5 tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`.
Args:
vocab_file (:obj:`str`):
Path to the vocabulary file.
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
The end of sequence token.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot
be converted to an ID and is set to be this token instead.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
extra_ids (:obj:`int`, `optional`, defaults to 100):
Add a number of extra ids added to the end of the vocabulary for use
as sentinels. These tokens are accessible as "<extra_id_{%d}>" where
"{%d}" is a number between 0 and extra_ids-1. Extra tokens are indexed
from the end of the vocabulary up to the beginning ("<extra_id_0>" is the
last token in the vocabulary like in T5 preprocessing see `here
<https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117>`__).
additional_special_tokens (:obj:`List[str]`, `optional`):
Additional special tokens used by the tokenizer.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>",
extra_ids=100,
additional_special_tokens=None,
add_bos_token=False,
**kwargs,
):
# Add extra_ids to the special token list
if extra_ids > 0 and additional_special_tokens is None:
additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
elif extra_ids > 0 and additional_special_tokens is not None:
extra_tokens = len(
set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens))
)
if extra_tokens != extra_ids:
raise ValueError(
f"Both extra_ids ({extra_ids}) and additional_special_tokens "
f"({additional_special_tokens}) are privided to T5Tokenizer. "
"In this case the additional_special_tokens must include the extra_ids tokens"
)
super().__init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
self.vocab_file = vocab_file
self._extra_ids = extra_ids
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
self.add_bos_token = add_bos_token
@property
def vocab_size(self):
return self.sp_model.get_piece_size() + self._extra_ids
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text):
"""Tokenize a string."""
pieces = self.sp_model.encode(text, out_type=str)
return pieces
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if token.startswith("<extra_id_"):
match = re.match(r"<extra_id_(\d+)>", token)
num = int(match.group(1))
return self.vocab_size - num - 1
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index < self.sp_model.get_piece_size():
token = self.sp_model.IdToPiece(index)
else:
token = f"<extra_id_{self.vocab_size - 1 - index}>"
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) to a single string."""
current_sub_tokens = []
out_string = ""
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
out_string += self.sp_model.decode_pieces(current_sub_tokens) + token + " "
current_sub_tokens = []
else:
current_sub_tokens.append(token)
out_string += self.sp_model.decode_pieces(current_sub_tokens)
return out_string.strip()
def _add_eos_if_not_present(self, token_ids):
if not self.add_bos_token:
return token_ids
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
warnings.warn("This sequence already has {self.eos_token}.")
return token_ids
else:
return token_ids + [self.eos_token_id]
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""Add special tokens to a sequence or a pair of sequence.
T5 format sentence input:
- single sequence: tokens_a </s>
- pair of sequences: tokens_a </s> tokens_b </s>
Args:
token_ids_0 (List[int]): The token ids of sentence 0.
token_ids_1 (List[int], optional): The token ids of sentence 1. Defaults to None.
Returns:
:obj:`List[int]`: The sequence after adding special tokens.
"""
token_ids_0 = self._add_eos_if_not_present(token_ids_0)
if token_ids_1 is None:
return token_ids_0
else:
token_ids_1 = self._add_eos_if_not_present(token_ids_1)
return token_ids_0 + token_ids_1
def save_vocabulary(self, save_directory, filename_prefix=None):
"""Save the tokenizer vocabulary to a directory or file."""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
logger.info(f"Copy vocab file to {out_vocab_file}")
return (out_vocab_file,)
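# A hedged usage sketch for ``T5Tokenizer``. The SentencePiece model path is
# hypothetical, and the sketch assumes the base class exposes
# ``tokenize``/``convert_tokens_to_ids``. The sentinel tokens are indexed from
# the end of the vocabulary, so "<extra_id_0>" maps to ``vocab_size - 1``.
# Illustrative only; not called anywhere in the library.
def _demo_t5_tokenizer(spiece_path="path/to/spiece.model"):  # hypothetical path
    tokenizer = T5Tokenizer(spiece_path, extra_ids=100, add_bos_token=True)
    assert tokenizer._convert_token_to_id("<extra_id_0>") == tokenizer.vocab_size - 1
    tokens = tokenizer.tokenize("translate English to German: hello")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # With add_bos_token=True, </s> is appended to each sequence.
    return tokenizer.build_inputs_with_special_tokens(ids)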
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
import shutil
from collections import defaultdict
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple
import numpy as np
import oneflow as flow
from oneflow import nn
from termcolor import colored
import libai.utils.distributed as dist
from libai.utils.file_io import HTTPURLHandler, PathManagerBase
class _IncompatibleKeys(
NamedTuple(
# pyre-fixme[10]: Name `IncompatibleKeys` is used but not defined.
"IncompatibleKeys",
[
("missing_keys", List[str]),
("unexpected_keys", List[str]),
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
("incorrect_shapes", List[Tuple]),
],
)
):
pass
class Checkpointer(object):
"""
A checkpointer that can save/load model as well as extra checkpointable
objects.
"""
# NOTE: only support data_parallel for saving model
# TODO: save model: support model_parallel and pipeline parallel
def __init__(
self,
model: nn.Module,
save_dir: str = "",
*,
save_to_disk: bool = True,
**checkpointables: object,
):
"""
Args:
model (nn.Module): model.
save_dir (str): a directory to save and find checkpoints.
save_to_disk (bool): if True, save checkpoint to disk, otherwise
disable saving for this checkpointer.
checkpointables (object): any checkpointable objects, i.e., objects
that have the `state_dict()` and `load_state_dict()` method. For
example, it can be used like
`Checkpointer(model, "dir", optimizer=optimizer)`.
"""
self.model = model
self.checkpointables = copy.copy(checkpointables)
self.logger = logging.getLogger(__name__)
self.save_dir = save_dir
self.save_to_disk = save_to_disk
# Default PathManager, support HTTP URLs
# A user may want to use a different project-specific PathManagerBase.
self.path_manager: PathManagerBase = PathManagerBase()
self.path_manager.register_handler(HTTPURLHandler())
def save(self, name: str, **kwargs: Dict[str, str]):
"""
Dump model and checkpointables to a file.
Args:
name (str): name of the file.
kwargs (dict): extra arbitrary data to save.
"""
data = {}
data["model"] = self.model.state_dict()
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)
basename = name
save_dir = os.path.join(self.save_dir, basename)
assert os.path.basename(save_dir) == basename, basename
if not self.path_manager.exists(save_dir):
self.path_manager.mkdirs(save_dir)
self.logger.info("Saving checkpoint to {}".format(save_dir))
for save_name in data:
if save_name == "iteration":
continue
save_file = os.path.join(save_dir, save_name)
# If the directory already exists, it will be reused for saving
if self.path_manager.exists(save_file):
self.path_manager.mkdirs(save_file)
flow.save(data[save_name], save_file, global_dst_rank=0)
if basename != "model_best":
self.tag_last_checkpoint(basename)
def load(self, path: str, checkpointables: Optional[List[str]] = None) -> object:
"""
Load from the given checkpoint. When path points to a network file, this
function has to be called on all ranks.
Args:
path (str): path or url to the checkpoint. If empty, will not load
anything.
checkpointables (list): List of checkpointable names to load. If not
specified (None), will load all the possible checkpointables.
Returns:
dict:
extra data loaded from the checkpoint that has not been
processed. For example, those saved with
:meth:`.save(**extra_data)`.
"""
if not path:
# no checkpoint provided
self.logger.info("No checkpoint found. Training model from scratch")
return {}
self.logger.info("Loading checkpoint from {}".format(path))
checkpoint = self._load_file(path)
incompatible = self._load_model(checkpoint)
if incompatible is not None: # handle some existing subclasses that return None
self._log_incompatible_keys(incompatible)
for key in self.checkpointables if checkpointables is None else checkpointables:
if key in checkpoint: # pyre-ignore
self.logger.info("Loading {} from {}".format(key, path))
obj = self.checkpointables[key]
obj.load_state_dict(checkpoint.pop(key)) # pyre-ignore
# return any further checkpoint data
return checkpoint
def has_checkpoint(self):
"""
Returns:
bool: whether a checkpoint exists in the target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
return self.path_manager.exists(save_file)
def get_checkpoint_file(self):
"""
Returns:
str: The latest checkpoint file in target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
try:
# load checkpoint file in rank0
if flow.env.get_rank() == 0:
with open(save_file, "r") as f:
last_saved = f.read().strip()
else:
last_saved = None
# broadcast checkpoint file to other ranks
last_saved = dist.broadcast_py_object(last_saved, src=0)
except IOError:
# if file doesn't exist, maybe because it has just been
# deleted by a separate process
return ""
return os.path.join(self.save_dir, last_saved)
def resume_or_load(self, path: str, *, resume: bool = True):
"""
If `resume` is True, this method attempts to resume from the last
checkpoint (if exists). Otherwise, load checkpoint from the given path.
This is useful when restarting an interrupted training job.
Args:
path (str): path to the checkpoint.
resume (bool): if True, resume from the last checkpoint if it exists.
Returns:
same as :meth:`load`.
"""
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
else:
return self.load(path, checkpointables=[])
def tag_last_checkpoint(self, last_filename_basename: str):
"""
Tag the last checkpoint.
Args:
last_filename_basename (str): the basename of the last filename.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
with self.path_manager.open(save_file, "w") as f:
f.write(last_filename_basename) # pyre-ignore
def _load_file(self, f: str):
"""
Load a checkpoint file. Can be overwritten by subclasses to support
different formats.
Args:
f (str): a locally mounted file path.
Returns:
dict: with keys "model" and optionally others that are saved by
the checkpointer. dict["model"] must be a dict which maps strings
to flow.Tensor or numpy arrays.
"""
data = {}
keys = self.path_manager.ls(f)
# broadcast checkpointer keys to other ranks
keys = dist.broadcast_py_object(keys, src=0)
for key in keys:
data[key] = flow.load(os.path.join(f, key), global_src_rank=0)
try:
data["iter"] = int(f.split("_")[-1])
except ValueError: # the path does not end with an iteration number
self.logger.info(f"iter info in {f} not found, set iter to 0")
data["iter"] = 0
return data
def _load_model(self, checkpoint: Any):
"""
Load weights from a checkpoint.
Args:
checkpoint (Any): checkpoint contains the weights.
"""
checkpoint_state_dict = checkpoint.pop("model")
self._convert_ndarray_to_tensor(checkpoint_state_dict)
# if the state_dict comes from a model that was wrapped in a
# DataParallel or DistributedDataParallel during serialization,
# remove the "module" prefix before performing the matching.
_strip_prefix_if_present(checkpoint_state_dict, "module.")
model_state_dict = self.model.state_dict()
incorrect_shapes = []
for k in list(checkpoint_state_dict.keys()):
if k in model_state_dict:
shape_model = tuple(model_state_dict[k].shape)
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
if shape_model != shape_checkpoint:
incorrect_shapes.append((k, shape_checkpoint, shape_model))
checkpoint_state_dict.pop(k)
incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
return _IncompatibleKeys(
missing_keys=incompatible.missing_keys,
unexpected_keys=incompatible.unexpected_keys,
incorrect_shapes=incorrect_shapes,
)
def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
"""
Log information about the incompatible keys returned by ``_load_model``.
"""
for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
self.logger.warning(
"Skip loading parameter '{}' to the model due to incompatible "
"shapes: {} in the checkpoint but {} in the "
"model! You might want to double check if this is expected.".format(
k, shape_checkpoint, shape_model
)
)
if incompatible.missing_keys:
missing_keys = _filter_reused_missing_keys(self.model, incompatible.missing_keys)
if missing_keys:
self.logger.info(get_missing_parameters_message(missing_keys))
if incompatible.unexpected_keys:
self.logger.info(get_unexpected_parameters_message(incompatible.unexpected_keys))
def _convert_ndarray_to_tensor(self, state_dict: dict):
"""
In-place convert all numpy arrays in the state_dict to flow tensor.
Args:
state_dict (dict): a state-dict to be loaded to the model.
"""
# model could be an OrderedDict with _metadata attribute
# (as returned by oneflow's state_dict()). We should preserve these
# properties.
for k in list(state_dict.keys()):
v = state_dict[k]
if not isinstance(v, np.ndarray) and not isinstance(v, flow.Tensor):
raise ValueError("Unsupported type found in checkpoint! {}: {}".format(k, type(v)))
# If it's local tensor, convert it to global tensor.
if not v.is_global:
if k in self.model.state_dict():
model_v = self.model.state_dict()[k]
state_dict[k] = v.to_global(sbp=model_v.sbp, placement=model_v.placement)
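# A hedged usage sketch for ``Checkpointer``. The model, optimizer and save
# directory below are placeholders; this function is illustrative only and is
# not called anywhere in the library.
def _demo_checkpointer(model, optimizer, save_dir="output/checkpoints"):
    checkpointer = Checkpointer(model, save_dir, optimizer=optimizer)
    # Save the model plus any extra checkpointables and arbitrary metadata.
    checkpointer.save("model_0000099", iteration=99)
    # Resume from the last checkpoint if one exists, otherwise start fresh.
    # ``extra`` carries leftover data, e.g. the iteration parsed from the
    # checkpoint directory name by ``_load_file``.
    extra = checkpointer.resume_or_load("", resume=True)
    return extra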
class PeriodicCheckpointer:
"""
Save checkpoints periodically. When `.step(iteration)` is called, it will
execute `checkpointer.save` on the given checkpointer, if iteration is a
multiple of period or if `max_iter` is reached.
"""
def __init__(
self,
checkpointer: Checkpointer,
period: int,
max_iter: Optional[int] = None,
max_to_keep: Optional[int] = None,
file_prefix: str = "model",
):
"""
Args:
checkpointer (Any): the checkpointer object used to save
checkpoints.
period (int): the period (in iterations) at which to save checkpoints.
max_iter (int): maximum number of iterations. When it is reached,
a checkpoint named "{file_prefix}_final" will be saved.
max_to_keep (int): maximum number of the most recent checkpoints to keep;
older checkpoints will be deleted.
file_prefix (str): the prefix of the checkpoint's filename.
"""
self.checkpointer = checkpointer
self.period = int(period)
self.max_iter = max_iter
if max_to_keep is not None:
assert max_to_keep > 0
self.max_to_keep = max_to_keep
self.recent_checkpoints: List[str] = []
self.file_prefix = file_prefix
self.path_manager: PathManagerBase = checkpointer.path_manager
def step(self, iteration: int, **kwargs: Any):
"""
Perform the appropriate action at the given iteration.
Args:
iteration (int): the current iteration, in the range [0, max_iter-1].
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
iteration = int(iteration)
additional_state = {"iteration": iteration}
additional_state.update(kwargs)
if (iteration + 1) % self.period == 0:
self.checkpointer.save(
"{}_{:07d}".format(self.file_prefix, iteration), **additional_state
)
if self.max_to_keep is not None:
self.recent_checkpoints.append(self.checkpointer.get_checkpoint_file())
if len(self.recent_checkpoints) > self.max_to_keep:
file_to_delete = self.recent_checkpoints.pop(0)
if (
dist.is_main_process()
and self.path_manager.exists(file_to_delete)
and not file_to_delete.endswith(
"{}_{:07d}".format(self.file_prefix, iteration)
)
):
if not self.path_manager.isfile(file_to_delete):
shutil.rmtree(file_to_delete)
else:
self.path_manager.rm(file_to_delete)
if self.max_iter is not None:
if iteration >= self.max_iter - 1:
self.checkpointer.save(f"{self.file_prefix}_final", **additional_state)
def save(self, name: str, **kwargs: Any):
"""
Same argument as :meth:`Checkpointer.save`.
Use this method to manually save checkpoints outside the schedule.
Args:
name (str): file name.
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
self.checkpointer.save(name, **kwargs)
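# A hedged sketch of driving ``PeriodicCheckpointer`` from a training loop.
# ``model`` and ``max_iter`` are placeholders; illustrative only, not called.
def _demo_periodic_checkpointer(model, max_iter=1000):
    checkpointer = Checkpointer(model, save_dir="output/checkpoints")
    periodic = PeriodicCheckpointer(
        checkpointer, period=100, max_iter=max_iter, max_to_keep=5
    )
    for iteration in range(max_iter):
        # ... run one training step here ...
        # Saves "model_0000099", "model_0000199", ... and "model_final" at the
        # end, keeping at most the 5 most recent periodic checkpoints on disk.
        periodic.step(iteration)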
def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
"""
Filter "missing keys" to not include keys that have been loaded with another name.
"""
keyset = set(keys)
param_to_names = defaultdict(set) # param -> names that point to it
for module_prefix, module in _named_modules_with_dup(model):
for name, param in list(module.named_parameters(recurse=False)) + list(
module.named_buffers(recurse=False) # pyre-ignore
):
full_name = (module_prefix + "." if module_prefix else "") + name
param_to_names[param].add(full_name)
for names in param_to_names.values():
# if one name appears missing but its alias exists, then this
# name is not considered missing
if any(n in keyset for n in names) and not all(n in keyset for n in names):
[keyset.remove(n) for n in names if n in keyset]
return list(keyset)
def get_missing_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the model but not found in a checkpoint.
Args:
keys (list[str]): List of keys that were not found in the checkpoint.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "Some model parameters or buffers are not found in the checkpoint:\n"
msg += "\n".join(" " + colored(k + _group_to_str(v), "blue") for k, v in groups.items())
return msg
def get_unexpected_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the checkpoint but not found in the model.
Args:
keys (list[str]): List of keys that were not found in the model.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
msg += "\n".join(" " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items())
return msg
def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
"""
Strip the prefix in metadata, if any.
Args:
state_dict (OrderedDict): a state-dict to be loaded to the model.
prefix (str): prefix.
"""
keys = sorted(state_dict.keys())
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
return
for key in keys:
newkey = key[len(prefix) :]
state_dict[newkey] = state_dict.pop(key)
# also strip the prefix in metadata, if any..
try:
metadata = state_dict._metadata # pyre-ignore
except AttributeError:
pass
else:
for key in list(metadata.keys()):
# for the metadata dict, the key can be:
# '': for the DDP module, which we want to remove.
# 'module': for the actual model.
# 'module.xx.xx': for the rest.
if len(key) == 0:
continue
newkey = key[len(prefix) :]
metadata[newkey] = metadata.pop(key)
def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
"""
Group keys based on common prefixes. A prefix is the string up to the final
"." in each key.
Args:
keys (list[str]): list of parameter names, i.e. keys in the model
checkpoint dict.
Returns:
dict[list]: keys with common prefixes are grouped into lists.
"""
groups = defaultdict(list)
for key in keys:
pos = key.rfind(".")
if pos >= 0:
head, tail = key[:pos], [key[pos + 1 :]]
else:
head, tail = key, []
groups[head].extend(tail)
return groups
def _group_to_str(group: List[str]) -> str:
"""
Format a group of parameter name suffixes into a loggable string.
Args:
group (list[str]): list of parameter name suffixes.
Returns:
str: formatted string.
"""
if len(group) == 0:
return ""
if len(group) == 1:
return "." + group[0]
return ".{" + ", ".join(group) + "}"
def _named_modules_with_dup(model: nn.Module, prefix: str = "") -> Iterable[Tuple[str, nn.Module]]:
"""
The same as `model.named_modules()`, except that it includes
duplicated modules that have more than one name.
"""
yield prefix, model
for name, module in model._modules.items(): # pyre-ignore
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
yield from _named_modules_with_dup(module, submodule_prefix)
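# Example sketch: how the grouping helpers above condense a flat key list into the
# compact form used by get_missing_parameters_message / get_unexpected_parameters_message.
# The key names are made up for illustration.
def _example_group_checkpoint_keys():
    keys = ["encoder.layer0.weight", "encoder.layer0.bias", "head.weight"]
    groups = _group_checkpoint_keys(keys)
    # groups == {"encoder.layer0": ["weight", "bias"], "head": ["weight"]}
    for prefix, suffixes in groups.items():
        print(prefix + _group_to_str(suffixes))
    # encoder.layer0.{weight, bias}
    # head.weight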
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import dill
import numpy as np
import oneflow as flow
from omegaconf import OmegaConf
from libai.config import try_get_key
logger = logging.getLogger(__name__)
_DIST_UTIL = None
def _merge_devices(devices):
num_gpus_per_node = get_world_size() // get_num_nodes()
node_devices = [node_id * num_gpus_per_node + device_id for node_id, device_id in devices]
return node_devices
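# Example sketch: the rank formula used by _merge_devices, spelled out for a
# hypothetical 2-node x 2-GPU setup. The real helper reads the live environment
# (get_world_size / get_num_nodes), so the values here are hard-coded for clarity.
def _example_merge_devices():
    num_gpus_per_node = 2
    devices = [(0, 0), (0, 1), (1, 0), (1, 1)]  # (node_id, device_id) pairs
    ranks = [node_id * num_gpus_per_node + device_id for node_id, device_id in devices]
    print(ranks)  # [0, 1, 2, 3]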
class _DistributeUtil(object):
def __init__(self, cfg):
self._init_distributed_env(cfg)
self._init_parallel_size(cfg)
self._init_placement_group(cfg)
self._init_parallel_hierarchy()
def _init_distributed_env(self, cfg):
"""Initialize the distributed environment."""
num_nodes = get_num_nodes()
num_gpus_per_node = get_world_size() // num_nodes
if try_get_key(cfg, "num_gpus_per_node", default=num_gpus_per_node) != num_gpus_per_node:
# The num_gpus_per_node saved in the config does not match the environment
# variable; warn the user about the inconsistent reproduction environment.
logger.warning(
    "'train.dist.num_gpus_per_node' is not equal to the environment variable. "
    f"{cfg.num_gpus_per_node} != {num_gpus_per_node}"
)
if try_get_key(cfg, "num_nodes", default=num_nodes) != num_nodes:
logger.warning(
    "'train.dist.num_nodes' is not equal to the "
    f"environment variable. {cfg.num_nodes} != {num_nodes}"
)
# Set the actual value to config
cfg.num_nodes = num_nodes
cfg.num_gpus_per_node = num_gpus_per_node
self._num_nodes = num_nodes
self._num_gpus_per_node = num_gpus_per_node
self._world_size = num_gpus_per_node * num_nodes
# Set the device type
self._device_type = try_get_key(cfg, "device_type", default="cuda")
def _init_parallel_size(self, cfg):
# tensor parallel size
self._tensor_parallel_size = min(cfg.tensor_parallel_size, self.world_size)
assert self.world_size % self._tensor_parallel_size == 0, (
f"world size ({self.world_size}) is not divisible by"
f" tensor parallel size ({self._tensor_parallel_size})"
)
# Set the actual tensor parallel size to cfg
cfg.tensor_parallel_size = self._tensor_parallel_size
# pipeline parallel size
self._pipeline_parallel_size = min(
cfg.pipeline_parallel_size, self.world_size // cfg.tensor_parallel_size
)
# Set the actual pipeline parallel size to cfg
cfg.pipeline_parallel_size = self._pipeline_parallel_size
if cfg.pipeline_parallel_size > 1:
assert (
try_get_key(cfg, "pipeline_num_layers") is not None
), "cfg.train.dist.pipeline_num_layers must be set when run pipeline parallel"
assert cfg.pipeline_num_layers >= self._pipeline_parallel_size, (
f"number of layers ({cfg.pipeline_num_layers}) is less than"
f" pipeline model parallel size ({self._pipeline_parallel_size})"
)
if try_get_key(cfg, "custom_pipeline_stage_id") is not None:
assert OmegaConf.is_list(
cfg.custom_pipeline_stage_id
), "type of cfg.train.dist.custom_pipeline_stage_id must be list"
cfg.custom_pipeline_stage_id = list(cfg.custom_pipeline_stage_id)
assert max(cfg.custom_pipeline_stage_id) < self._world_size, (
f"the element {max(cfg.custom_pipeline_stage_id)} in"
" cfg.train.dist.custom_pipeline_stage_id is out of range"
f" for total rank {self._world_size}"
)
assert len(cfg.custom_pipeline_stage_id) == cfg.pipeline_num_layers, (
"the length of cfg.train.dist.custom_pipeline_stage_id"
f" {len(cfg.custom_pipeline_stage_id)} must be equal to"
" cfg.train.dist.pipeline_num_layers"
f" {cfg.train.dist.pipeline_num_layers}"
)
else:
# No pipeline parallelism; fall back to a large default value (10000).
if try_get_key(cfg, "pipeline_num_layers") is None:
cfg.pipeline_num_layers = 10000
self._model_parallel_size = self._pipeline_parallel_size * self._tensor_parallel_size
assert self.world_size % self._model_parallel_size == 0, (
f"world size ({self.world_size}) is not divisible by"
f" tensor model parallel size ({self._tensor_parallel_size}) times"
f" pipeline model parallel size ({self._pipeline_parallel_size})"
)
# data parallel size
self._data_parallel_size = self.world_size // self._model_parallel_size
# Set the actual data parallel size to cfg
cfg.data_parallel_size = self._data_parallel_size
def _init_placement_group(self, cfg):
node_ids = [i // self.num_gpus_per_node for i in range(self.world_size)]
device_ids = list(range(self.num_gpus_per_node)) * self.num_nodes
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]
devices = [(n, d) for n, d in zip(node_ids, device_ids)]
num_devices_per_stage = self.world_size // self._pipeline_parallel_size
stages_devices = [
_merge_devices(devices[i : (i + num_devices_per_stage)])
for i in range(0, self.world_size, num_devices_per_stage)
]
# change pipeline_num_layers to make the middle stages contain more layers
if (
self._pipeline_parallel_size >= 4
and cfg.pipeline_num_layers >= 8
and cfg.pipeline_num_layers % self._pipeline_parallel_size == 0
):
temp_num_layers_per_stage = cfg.pipeline_num_layers // self._pipeline_parallel_size
actual_pipeline_num_layers = cfg.pipeline_num_layers + min(
self._pipeline_parallel_size - 1, temp_num_layers_per_stage
)
else:
actual_pipeline_num_layers = cfg.pipeline_num_layers
num_layers_per_stage = actual_pipeline_num_layers // self._pipeline_parallel_size
stage_offset = actual_pipeline_num_layers % self._pipeline_parallel_size
# stage_offset can make the later stages contain more layers when pipeline_num_layers
# cannot be divided by pipeline_parallel_size.
# This can make pipeline parallel more memory efficient.
self._layer_stage_ids = []
for i in range(0, actual_pipeline_num_layers - stage_offset, num_layers_per_stage):
stage_id = i // num_layers_per_stage
if stage_id >= (self._pipeline_parallel_size - stage_offset):
self._layer_stage_ids.append(stage_id)
self._layer_stage_ids.extend([stage_id] * num_layers_per_stage)
self._layer_stage_ids = self._layer_stage_ids[: cfg.pipeline_num_layers]
# When pipeline_parallel_size > 1, add the pipeline_stage_id information to cfg.
if cfg.pipeline_parallel_size > 1:
cfg.auto_pipeline_stage_id = self._layer_stage_ids
# Override the pipeline stage ids with the user's setting, if provided.
if try_get_key(cfg, "custom_pipeline_stage_id") is not None:
self._layer_stage_ids = cfg.custom_pipeline_stage_id
cfg.actual_pipeline_stage_id = self._layer_stage_ids
self._layer_ranks = [stages_devices[stage_id] for stage_id in self._layer_stage_ids]
def _init_parallel_hierarchy(self):
if self.is_data_model_parallel():
self._parallel_hierarchy = (
self._data_parallel_size,
self._tensor_parallel_size,
)
else:
self._parallel_hierarchy = None
@property
def num_nodes(self):
return self._num_nodes
@property
def num_gpus_per_node(self):
return self._num_gpus_per_node
@property
def world_size(self):
return self._world_size
@property
def parallel_hierarchy(self):
return self._parallel_hierarchy
@property
def tensor_parallel_size(self):
return self._tensor_parallel_size
@property
def pipeline_parallel_size(self):
return self._pipeline_parallel_size
@property
def model_parallel_size(self):
return self._tensor_parallel_size
@property
def data_parallel_size(self):
return self._data_parallel_size
@property
def device_type(self):
return self._device_type
def set_device_type(self, device_type):
assert device_type in ["cpu", "cuda"], f"not supported for {device_type}"
self._device_type = device_type
def get_layer_ranks(self, layer_idx):
layer_ranks = self._layer_ranks[layer_idx]
if self._parallel_hierarchy is None:
return layer_ranks
else:
assert len(self._parallel_hierarchy) == 2
return np.asarray(layer_ranks).reshape(self._parallel_hierarchy).tolist()
def get_layer_stage_id(self, layer_idx):
return self._layer_stage_ids[layer_idx]
def is_tensor_model_parallel(self):
return self._tensor_parallel_size > 1
def is_data_parallel(self):
return self._data_parallel_size > 1
def is_pipeline_model_parallel(self):
return self._pipeline_parallel_size > 1
def is_data_model_parallel(self):
return self.is_tensor_model_parallel() and self.is_data_parallel()
def setup_dist_util(cfg):
"""Initialize the distributed environment with configuration.
Example:
.. code-block:: python
from omegaconf import DictConfig
# set the hybrid parallel distributed environment with 2D mesh GPUs
setup_dist_util(
DictConfig(
dict(
data_parallel_size=2,
tensor_parallel_size=2,
pipeline_parallel_size=1,
)
)
)
"""
global _DIST_UTIL
_DIST_UTIL = _DistributeUtil(cfg)
def get_dist_util():
"""Get distributed utils if it's been setup. Otherwise, initialize it with
single node/single gpu environment."""
global _DIST_UTIL
if _DIST_UTIL is None:
logger.warning(
"Distributed env is not set up, configure it by default (single node, single gpu)."
)
from omegaconf import DictConfig
setup_dist_util(
DictConfig(
dict(
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=1,
)
)
)
return _DIST_UTIL
def get_layer_placement(layer_idx, device_type=None):
"""
Get ``flow.placement`` object with the initialized distributed environment
according to the ``layer_idx``.
Args:
layer_idx (int): layer index indicating the rank groups. This is very useful for pipeline
parallelism training where different layers are on different ranks.
device_type (str, optional): device type. Defaults to the distributed util's device type (usually "cuda").
"""
dist_util = get_dist_util()
device_type = dist_util.device_type if device_type is None else device_type
if not flow.cuda.is_available() and device_type == "cuda":
device_type = "cpu"
return flow.placement(
device_type,
dist_util.get_layer_ranks(layer_idx),
)
def get_nd_sbp(sbp_list):
"""Get nd sbp signature list, which is consistent with 1D/2D mesh GPUs.
Args:
sbp_list (list): a sbp list with 2D mesh.
Returns:
A modified sbp list according to the initialized distributed environment.
"""
assert isinstance(sbp_list, list)
assert len(sbp_list) == 2
assert all(isinstance(sbp, flow.sbp.sbp) for sbp in sbp_list)
dist_util = get_dist_util()
if dist_util.is_data_model_parallel():
return sbp_list
elif dist_util.is_data_parallel():
return sbp_list[:1]
elif dist_util.is_tensor_model_parallel():
return sbp_list[1:]
else:
return [flow.sbp.broadcast]
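# Example sketch: how a 2D sbp signature collapses for the different parallel modes.
# Only the single-device fallback is runnable as-is; the hybrid cases require a
# multi-device launch and are shown as comments.
def _example_get_nd_sbp():
    sbp_2d = [flow.sbp.split(0), flow.sbp.broadcast]
    # data parallel only          -> [flow.sbp.split(0)]
    # tensor model parallel only  -> [flow.sbp.broadcast]
    # data + tensor parallel      -> [flow.sbp.split(0), flow.sbp.broadcast]
    # no parallelism (1 device)   -> [flow.sbp.broadcast]
    print(get_nd_sbp(sbp_2d))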
def get_hidden_sbp():
"""Hidden states sbp."""
return get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast])
def get_data_parallel_rank():
dist_util = get_dist_util()
return (flow.env.get_rank() // dist_util.model_parallel_size) % dist_util.data_parallel_size
def get_data_parallel_size():
dist_util = get_dist_util()
return dist_util.data_parallel_size
def get_tensor_parallel_size():
dist_util = get_dist_util()
return dist_util.tensor_parallel_size
def get_pipeline_parallel_size():
dist_util = get_dist_util()
return dist_util.pipeline_parallel_size
def same_sbp(lhs_sbp, rhs_sbp):
"""Determine if two sbp signatures are the same."""
assert len(lhs_sbp) == len(rhs_sbp)
for i in range(len(lhs_sbp)):
if lhs_sbp[i] != rhs_sbp[i]:
return False
return True
def get_rank() -> int:
return flow.env.get_rank()
def get_local_rank() -> int:
return flow.env.get_local_rank()
def is_main_process() -> bool:
return get_rank() == 0
def is_last_process() -> bool:
return get_rank() == get_world_size() - 1
def get_world_size():
return flow.env.get_world_size()
def get_num_nodes():
return flow.env.get_node_size()
def set_device_type(device_type):
dist_util = get_dist_util()
dist_util.set_device_type(device_type)
def broadcast_py_object(obj, src: int = 0):
rank = flow.env.get_rank()
if src == rank:
obj_bytes = dill.dumps(obj)
return dill.loads(flow._oneflow_internal.cpu_broadcast(obj_bytes, src))
else:
return dill.loads(flow._oneflow_internal.cpu_broadcast(None, src))
def convert_to_distributed_default_setting(t):
"""
Helper function to convert an eager local tensor (e.g. a module parameter) to a
global tensor, with data parallelism as the default distributed setting.
"""
if not t.is_global:
return t.to_global(
sbp=get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=get_layer_placement(0),
)
else:
dist_util = get_dist_util()
device_type = dist_util.device_type
return t.to_global(placement=flow.placement(device_type, ranks=t.placement.ranks))
def ttol(tensor, pure_local=False, ranks=None):
"""Global tensor to local tensor."""
if tensor.is_global:
placement = tensor.placement if not ranks else flow.placement("cuda", ranks)
if pure_local:
tensor = tensor.to_global(placement=placement).to_local()
else:
tensor = tensor.to_global(
sbp=get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=placement
).to_local()
return tensor
def tton(tensor, local_only=False, ranks=None):
"""Global tensor to numpy ndarray."""
if tensor.is_global:
tensor = ttol(tensor, local_only, ranks)
return tensor.numpy()
def tensor_to_rank0(tensor, device="cuda", to_local=False):
"""Global tensor to rank0."""
assert device in ["cpu", "cuda"], f"not supported for device:{device}"
if tensor.is_global:
# Consider if it's 2d mesh, ranks should be [[0]] instead of [0]
placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
tensor = tensor.to_global(
sbp=get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=placement
)
if to_local:
tensor = ttol(tensor)
return tensor
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training.
"""
world_size = get_world_size()
if world_size == 1:
return
flow.comm.barrier()
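# Example sketch: a minimal round trip with the helpers above, assuming the default
# single-node fallback environment set up by get_dist_util().
def _example_global_tensor_roundtrip():
    x = flow.ones(2, 2).to_global(
        sbp=get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]),
        placement=get_layer_placement(0),
    )
    # Gather the global tensor back into a local numpy array.
    print(tton(x))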
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
from typing import Callable, List, Optional
from urllib import request
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/iopath/blob/main/iopath/common/download.py
# --------------------------------------------------------
def download(url: str, dir: str, *, filename: Optional[str] = None, progress: bool = True) -> str:
"""
Download a file from a given URL to a directory. If the file already exists, it will
not be overwritten.
Args:
url (str): the URL to download from.
dir (str): the directory to download the file to.
filename (str or None): the basename to save the file.
Will use the name in the URL if not given.
progress (bool): whether to use tqdm to draw a progress bar.
Returns:
str: the path to the downloaded file or the existing one.
"""
os.makedirs(dir, exist_ok=True)
if filename is None:
filename = url.split("/")[-1]
assert len(filename), "Cannot obtain filename from url {}".format(url)
fpath = os.path.join(dir, filename)
logger = logging.getLogger(__name__)
if os.path.isfile(fpath):
logger.info("File {} exists! Skipping download.".format(filename))
return fpath
tmp = fpath + ".tmp" # download to a tmp file first, to be more atomic.
try:
logger.info("Downloading from {} ...".format(url))
if progress:
import tqdm
def hook(t: tqdm.tqdm) -> Callable[[int, int, Optional[int]], None]:
last_b: List[int] = [0]
def inner(b: int, bsize: int, tsize: Optional[int] = None) -> None:
if tsize is not None:
t.total = tsize
t.update((b - last_b[0]) * bsize) # type: ignore
last_b[0] = b
return inner
with tqdm.tqdm( # type: ignore
unit="B", unit_scale=True, miniters=1, desc=filename, leave=True
) as t:
tmp, _ = request.urlretrieve(url, filename=tmp, reporthook=hook(t))
else:
tmp, _ = request.urlretrieve(url, filename=tmp)
statinfo = os.stat(tmp)
size = statinfo.st_size
if size == 0:
raise IOError("Downloaded an empty file from {}!".format(url))
# download to tmp first and move to fpath, to make this function more
# atomic.
shutil.move(tmp, fpath)
except IOError:
logger.error("Failed to download {}".format(url))
raise
finally:
try:
os.unlink(tmp)
except IOError:
pass
logger.info("Successfully downloaded " + fpath + ". " + str(size) + " bytes.")
return fpath
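# Example sketch: typical use of download(). The URL and target directory below are
# placeholders, not real project resources.
def _example_download_usage():
    path = download(
        "https://example.com/checkpoints/model.bin",  # hypothetical URL
        "/tmp/libai_downloads",
        progress=False,
    )
    print("saved to", path)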
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import json
import logging
import os
import time
from collections import defaultdict
from contextlib import contextmanager
from libai.utils.file_io import PathManager
from libai.utils.history_buffer import HistoryBuffer
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/events.py
# --------------------------------------------------------
__all__ = [
"get_event_storage",
"JSONWriter",
"CommonMetricPrinter",
"EventStorage",
]
_CURRENT_STORAGE_STACK = []
def get_event_storage():
"""
Returns:
The :class:`EventStorage` object that's currently being used.
Throw an error if no :class:`EventStorage` is currently enabled.
"""
assert len(
_CURRENT_STORAGE_STACK
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
return _CURRENT_STORAGE_STACK[-1]
class EventWriter:
"""
Base class for writers that obtain events from :class:`EventStorage` and process them.
"""
def write(self):
raise NotImplementedError
def close(self):
pass
class JSONWriter(EventWriter):
"""
Write scalars to a json file.
It saves scalars as one json per line (instead of a big json) for easy parsing.
Example of parsing such a json file:
::
$ cat metrics.json | jq -s '.[0:2]'
[
{
"data_time": 0.008433341979980469,
"iteration": 19,
"total_loss": 1.9228371381759644,
"lr": 0.007173333333333333,
"time": 0.25401854515075684
},
{
"data_time": 0.007216215133666992,
"iteration": 39,
"total_loss": 1.282649278640747,
"lr": 0.007706666666666667,
"time": 0.2490077018737793
}
]
$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
"""
def __init__(self, json_file, window_size=20):
"""
Args:
json_file (str): path to the json file. New data will be appended if the file exists.
window_size (int): the window size of median smoothing for the scalars whose
`smoothing_hint` is True.
"""
self._file_handle = PathManager.open(json_file, "a")
self._window_size = window_size
self._last_write = -1
def write(self):
storage = get_event_storage()
to_save = defaultdict(dict)
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
# keep scalars that have not been written
if iter <= self._last_write:
continue
to_save[iter][k] = v
if len(to_save):
all_iters = sorted(to_save.keys())
self._last_write = max(all_iters)
for itr, scalars_per_iter in to_save.items():
scalars_per_iter["iteration"] = itr
self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
self._file_handle.flush()
try:
os.fsync(self._file_handle.fileno())
except AttributeError:
pass
def close(self):
self._file_handle.close()
class TensorboardXWriter(EventWriter):
"""
Write all scalars to a tensorboard file
"""
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
"""
Args:
log_dir (str): the directory to save the output events
window_size (int): the scalars will be median-smoothed by this window size
kwargs: other arguments passed to `tensorboardX.SummaryWriter(...)`
"""
self._window_size = window_size
from tensorboardX import SummaryWriter
self._writer = SummaryWriter(log_dir=log_dir, **kwargs)
self._last_write = -1
def write(self):
storage = get_event_storage()
new_last_write = self._last_write
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
if iter > self._last_write:
self._writer.add_scalar(k, v, iter)
new_last_write = max(new_last_write, iter)
self._last_write = new_last_write
# TODO: add write image
if len(storage._histograms) >= 1:
for params in storage._histograms:
self._writer.add_histogram_raw(**params)
storage.clear_histograms()
def close(self):
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
self._writer.close()
class CommonMetricPrinter(EventWriter):
"""
Print **common** metrics to the terminal, including
iteration time, ETA, memory, all losses, and the learning rate.
It also applies smoothing using a window of 20 elements.
It's meant to print common metrics in common ways.
To print something in a more customized way, implement a similar printer yourself.
"""
def __init__(self, batch_size, max_iter):
"""
Args:
batch_size (int): the number of samples consumed per iteration, used to compute throughput.
max_iter (int): the maximum number of iterations to train.
    Used to compute ETA.
"""
self.logger = logging.getLogger(__name__)
self._batch_size = batch_size
self._max_iter = max_iter
self._last_write = None
def write(self):
storage = get_event_storage()
iteration = storage.iter
consumed_samples = storage.samples
if iteration == self._max_iter:
# This hook only reports training progress (loss, ETA, etc) but not other data,
# therefore do not write anything after training succeeds, even if this method
# is called.
return
try:
data_time = storage.history("data_time").avg(20)
except KeyError:
# they may not exist in the first few iterations (due to warmup)
# or when SimpleTrainer is not used
data_time = None
eta_string = None
try:
iter_time = storage.history("time").global_avg()
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
except KeyError:
iter_time = None
# estimate eta on our own - more noisy
if self._last_write is not None:
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
iteration - self._last_write[0]
)
eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
self._last_write = (iteration, time.perf_counter())
try:
lr = "{:.2e}".format(storage.history("lr").latest())
except KeyError:
lr = "N/A"
max_mem_mb = None
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
" {eta}{iter} {sample} {losses} {time}{data_time} {tpt} lr: {lr} {memory}".format(
eta=f"eta: {eta_string} " if eta_string else "",
iter=f"iteration: {iteration}/{self._max_iter}",
sample=f"consumed_samples: {consumed_samples}",
losses=" ".join(
[
"{}: {:.4g}".format(k, v.median(200))
for k, v in storage.histories().items()
if "loss" in k
]
),
time="time: {:.4f} s/iter ".format(iter_time) if iter_time is not None else "",
data_time="data_time: {:.4f} s/iter".format(data_time)
if data_time is not None
else "",
tpt="total_throughput: {:.2f} samples/s".format(self._batch_size / iter_time)
if iter_time is not None
else "",
lr=lr,
memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
)
)
class EventStorage:
"""
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
"""
def __init__(self, start_iter=0):
"""
Args:
start_iter (int): the iteration number to start with
"""
self._history = defaultdict(HistoryBuffer)
self._smoothing_hints = {}
self._latest_scalars = {}
self._iter = start_iter
self._batch_size = 0
self._samples = 0  # consumed samples; normally updated through the ``samples`` setter
self._current_prefix = ""
self._vis_data = []
self._histograms = []
def put_image(self, img_name, img_tensor):
"""
Add an `img_tensor` associated with `img_name` to be shown on
tensorboard.
Args:
img_name (str): The name of the image to put into tensorboard.
img_tensor (flow.Tensor or numpy.array): A `uint8` or `float`
Tensor of shape `[channel, height, width]` where `channel` is
3. The image format should be RGB. The elements in img_tensor
can either have values in [0, 1] (float32) or [0, 255] (uint8).
The `img_tensor` will be visualized in tensorboard.
"""
self._vis_data.append((img_name, img_tensor, self._iter))
def put_scalar(self, name, value, smoothing_hint=True):
"""
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
Args:
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
smoothed when logged. The hint will be accessible through
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
and apply custom smoothing rule.
It defaults to True because most scalars we save need to be smoothed to
provide any useful signal.
"""
name = self._current_prefix + name
history = self._history[name]
value = float(value)
history.update(value, self._iter)
self._latest_scalars[name] = (value, self._iter)
existing_hint = self._smoothing_hints.get(name)
if existing_hint is not None:
assert (
existing_hint == smoothing_hint
), "Scalar {} was put with a different smoothing_hint!".format(name)
else:
self._smoothing_hints[name] = smoothing_hint
def put_scalars(self, *, smoothing_hint=True, **kwargs):
"""
Put multiple scalars from keyword arguments.
Example:
.. code-block:: python
storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
"""
for k, v in kwargs.items():
self.put_scalar(k, v, smoothing_hint=smoothing_hint)
def history(self, name):
"""
Returns:
HistoryBuffer: the scalar history for name
"""
ret = self._history.get(name, None)
if ret is None:
raise KeyError("No history metric available for {}!".format(name))
return ret
def histories(self):
"""
Returns:
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
"""
return self._history
def latest(self):
"""
Returns:
dict[str -> (float, int)]: mapping from the name of each scalar to the most
recent value and the iteration number at which it was added.
"""
return self._latest_scalars
def latest_with_smoothing_hint(self, window_size=20):
"""
Similar to :meth:`latest`, but the returned values
are either the un-smoothed original latest value,
or a median of the given window_size,
depending on whether the smoothing_hint is True.
This provides a default behavior that other writers can use.
"""
result = {}
for k, (v, itr) in self._latest_scalars.items():
result[k] = (
self._history[k].median(window_size) if self._smoothing_hints[k] else v,
itr,
)
return result
def smoothing_hints(self):
"""
Returns:
dict[name -> bool]: the user-provided hint on whether the scalar
is noisy and needs smoothing.
"""
return self._smoothing_hints
def step(self):
"""
Users should either (1) call this function to increment storage.iter when needed,
or (2) set `storage.iter` to the correct iteration number before each iteration.
The storage will then be able to associate the new data with an iteration number.
"""
self._iter += 1
@property
def iter(self):
"""
Returns the current iteration number. When used together with a trainer,
this is ensured to be the same as trainer.iter.
"""
return self._iter
@iter.setter
def iter(self, val):
self._iter = int(val)
@property
def samples(self):
return self._samples
@samples.setter
def samples(self, val):
self._samples = int(val)
def __enter__(self):
_CURRENT_STORAGE_STACK.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
assert _CURRENT_STORAGE_STACK[-1] == self
_CURRENT_STORAGE_STACK.pop()
@contextmanager
def name_scope(self, name):
"""
Yields:
A context within which all the events added to this storage
will be prefixed by the name scope.
"""
old_prefix = self._current_prefix
self._current_prefix = name.rstrip("/") + "/"
yield
self._current_prefix = old_prefix
def clear_images(self):
"""
Delete all the stored images for visualization. This should be called
after images are written to tensorboard.
"""
self._vis_data = []
def clear_histograms(self):
"""
Delete all the stored histograms for visualization.
This should be called after histograms are written to tensorboard.
"""
self._histograms = []
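# Example sketch: minimal interplay between EventStorage and a writer. The metric
# values and output path are made up for illustration.
def _example_event_storage_usage(json_path="/tmp/metrics.json"):
    writer = JSONWriter(json_path, window_size=20)
    with EventStorage(start_iter=0) as storage:
        for it in range(3):
            storage.put_scalar("total_loss", 1.0 / (it + 1))
            storage.put_scalar("lr", 1e-3, smoothing_hint=False)
            writer.write()
            storage.step()
    writer.close()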
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import concurrent.futures
import errno
import logging
import os
import shutil
import tempfile
import traceback
from collections import OrderedDict
from typing import IO, Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Set, Union
from urllib.parse import urlparse
import portalocker
from libai.utils.download import download
from libai.utils.non_blocking_io import NonBlockingIOManager
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/iopath/blob/main/iopath/common/file_io.py
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/file_io.py
# --------------------------------------------------------
__all__ = ["LazyPath", "PathManager", "get_cache_dir", "file_lock"]
def get_cache_dir(cache_dir: Optional[str] = None) -> str:
"""
Returns a default directory to cache static files
(usually downloaded from Internet), if None is provided.
Args:
cache_dir (None or str): if not None, will be returned as is.
If None, returns the default cache directory as:
1) $LIBAI_CACHE, if set
2) otherwise ~/.oneflow/iopath_cache
"""
if cache_dir is None:
cache_dir = os.path.expanduser(os.getenv("LIBAI_CACHE", "~/.oneflow/iopath_cache"))
try:
os.makedirs(cache_dir, exist_ok=True)
assert os.access(cache_dir, os.W_OK)
except (OSError, AssertionError):
tmp_dir = os.path.join(tempfile.gettempdir(), "iopath_cache")
logger = logging.getLogger(__name__)
logger.warning(f"{cache_dir} is not accessible! Using {tmp_dir} instead!")
cache_dir = tmp_dir
return cache_dir
def file_lock(path: str): # type: ignore
"""
A file lock. Once entered, it is guaranteed that no one else holds the
same lock. Others trying to enter the lock will block for up to one hour and
raise an exception.
This is useful to make sure workers don't cache files to the same location.
Args:
path (str): a path to be locked. This function will create a lock named
`path + ".lock"`
Examples:
filename = "/path/to/file"
with file_lock(filename):
if not os.path.isfile(filename):
do_create_file()
"""
dirname = os.path.dirname(path)
try:
os.makedirs(dirname, exist_ok=True)
except OSError:
# makedir is not atomic. Exceptions can happen when multiple workers try
# to create the same dir, despite exist_ok=True.
# When this happens, we assume the dir is created and proceed to creating
# the lock. If failed to create the directory, the next line will raise
# exceptions.
pass
return portalocker.Lock(path + ".lock", timeout=3600) # type: ignore
class LazyPath(os.PathLike):
"""
A path that's lazily evaluated when it's used.
Users should be careful to not use it like a str, because
it behaves differently from a str.
Path manipulation functions in Python such as `os.path.*` all accept
PathLike objects already.
It can be materialized to a str using `os.fspath`.
"""
def __init__(self, func: Callable[[], str]) -> None:
"""
Args:
func: a function that takes no arguments and returns the
actual path as a str. It will be called at most once.
"""
self._func = func
self._value: Optional[str] = None
def _get_value(self) -> str:
if self._value is None:
self._value = self._func()
return self._value # pyre-ignore
def __fspath__(self) -> str:
return self._get_value()
# behave more like a str after being evaluated
def __getattr__(self, name: str): # type: ignore
if self._value is None:
raise AttributeError(f"Uninitialized LazyPath has no attribute: {name}.")
return getattr(self._value, name)
def __getitem__(self, key): # type: ignore
if self._value is None:
raise TypeError("Uninitialized LazyPath is not subscriptable.")
return self._value[key] # type: ignore
def __str__(self) -> str:
if self._value is not None:
return self._value # type: ignore
else:
return super().__str__()
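# Example sketch: LazyPath defers the (possibly expensive) path computation until the
# path is actually used. The path below is hypothetical.
def _example_lazy_path_usage():
    lazy = LazyPath(lambda: "/tmp/libai_cache/model.bin")
    # Nothing is computed until os.fspath() (or any str-like access) is called.
    print(os.fspath(lazy))
    # After evaluation, str attributes are forwarded to the underlying value.
    print(lazy.endswith(".bin"))  # True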
class PathHandler:
"""
PathHandler is a base class that defines common I/O functionality for a URI
protocol. It routes I/O for a generic URI which may look like "protocol://*"
or a canonical filepath "/foo/bar/baz".
"""
_strict_kwargs_check = True
def __init__(
self,
async_executor: Optional[concurrent.futures.Executor] = None,
) -> None:
"""
When registering a `PathHandler`, the user can optionally pass in a
`Executor` to run the asynchronous file operations.
NOTE: For regular non-async operations of `PathManager`, there is
no need to pass `async_executor`.
Args:
async_executor (optional `Executor`): Used for async file operations.
Usage:
```
path_handler = NativePathHandler(async_executor=exe)
path_manager.register_handler(path_handler)
```
"""
self._non_blocking_io_manager = None
self._non_blocking_io_executor = async_executor
def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
"""
Checks if the given arguments are empty. Throws a ValueError if strict
kwargs checking is enabled and args are non-empty. If strict kwargs
checking is disabled, only a warning is logged.
Args:
kwargs (Dict[str, Any])
"""
if self._strict_kwargs_check:
if len(kwargs) > 0:
raise ValueError("Unused arguments: {}".format(kwargs))
else:
logger = logging.getLogger(__name__)
for k, v in kwargs.items():
logger.warning("[PathManager] {}={} argument ignored".format(k, v))
def _get_supported_prefixes(self) -> List[str]:
"""
Returns:
List[str]: the list of URI prefixes this PathHandler can support
"""
raise NotImplementedError()
def _get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk. In this case, the cache stays on filesystem
(under `file_io.get_cache_dir()`) and will be used by a different run.
Therefore this function is meant to be used with read-only resources.
Args:
path (str): A URI supported by this PathHandler
force(bool): Forces a download from backend if set to True.
Returns:
local_path (str): a file path which exists on the local file system
"""
raise NotImplementedError()
def _copy_from_local(
self, local_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> None:
"""
Copies a local file to the specified URI.
If the URI is another local path, this should be functionally identical
to copy.
Args:
local_path (str): a file path which exists on the local file system
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing URI
Returns:
status (bool): True on success
"""
raise NotImplementedError()
def _opent(
self, path: str, mode: str = "r", buffering: int = 32, **kwargs: Any
) -> Iterable[Any]:
raise NotImplementedError()
def _open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
raise NotImplementedError()
def _opena(
self,
path: str,
mode: str = "r",
callback_after_file_close: Optional[Callable[[None], None]] = None,
buffering: int = -1,
**kwargs: Any,
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI with asynchronous permissions.
NOTE: Writes to the same path are serialized so they are written in
the same order as they were called but writes to distinct paths can
happen concurrently.
Usage (default / without callback function):
for n in range(50):
results = run_a_large_task(n)
with path_manager.opena(uri, "w") as f:
f.write(results) # Runs in separate thread
# Main process returns immediately and continues to next iteration
path_manager.async_close()
Usage (advanced / with callback function):
# To write local and then copy to Manifold:
def cb():
path_manager.copy_from_local(
"checkpoint.pt", "manifold://path/to/bucket"
)
f = pm.opena("checkpoint.pt", "wb", callback_after_file_close=cb)
flow.save({...}, f)
f.close()
Args:
...same args as `_open`...
callback_after_file_close (Callable): An optional argument that can
be passed to perform operations that depend on the asynchronous
writes being completed. The file is first written to the local
disk and then the callback is executed.
buffering (int): An optional argument to set the buffer size for
buffered asynchronous writing.
Returns:
file: a file-like object with asynchronous methods.
"""
# Restrict mode until `NonBlockingIO` has async read feature.
valid_modes = {"w", "a", "b"}
if not all(m in valid_modes for m in mode):
raise ValueError("`opena` mode must be write or append")
# TODO: Each `PathHandler` should set its own `self._buffered`
# parameter and pass that in here. Until then, we assume no
# buffering for any storage backend.
if not self._non_blocking_io_manager:
self._non_blocking_io_manager = NonBlockingIOManager(
buffered=False,
executor=self._non_blocking_io_executor,
)
try:
return self._non_blocking_io_manager.get_non_blocking_io(
path=self._get_path_with_cwd(path),
io_obj=self._open(path, mode, **kwargs),
callback_after_file_close=callback_after_file_close,
buffering=buffering,
)
except ValueError:
# When `_strict_kwargs_check = True`, then `open_callable`
# will throw a `ValueError`. This generic `_opena` function
# does not check the kwargs since it may include any `_open`
# args like `encoding`, `ttl`, `has_user_data`, etc.
logger = logging.getLogger(__name__)
logger.exception(
"An exception occurred in `NonBlockingIOManager`. This "
"is most likely due to invalid `opena` args. Make sure "
"they match the `open` args for the `PathHandler`."
)
self._async_close()
def _async_join(self, path: Optional[str] = None, **kwargs: Any) -> bool:
"""
Ensures that desired async write threads are properly joined.
Args:
path (str): Pass in a file path to wait until all asynchronous
activity for that path is complete. If no path is passed in,
then this will wait until all asynchronous jobs are complete.
Returns:
status (bool): True on success
"""
if not self._non_blocking_io_manager:
logger = logging.getLogger(__name__)
logger.warning(
"This is an async feature. No threads to join because " "`opena` was not used."
)
self._check_kwargs(kwargs)
return self._non_blocking_io_manager._join(self._get_path_with_cwd(path) if path else None)
def _async_close(self, **kwargs: Any) -> bool:
"""
Closes the thread pool used for the asynchronous operations.
Returns:
status (bool): True on success
"""
if not self._non_blocking_io_manager:
logger = logging.getLogger(__name__)
logger.warning(
"This is an async feature. No threadpool to close because " "`opena` was not used."
)
self._check_kwargs(kwargs)
return self._non_blocking_io_manager._close_thread_pool()
def _copy(self, src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
raise NotImplementedError()
def _mv(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Moves (renames) a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
Returns:
status (bool): True on success
"""
raise NotImplementedError()
def _exists(self, path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
raise NotImplementedError()
def _isfile(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
raise NotImplementedError()
def _isdir(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
raise NotImplementedError()
def _ls(self, path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
raise NotImplementedError()
def _mkdirs(self, path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
def _rm(self, path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
def _symlink(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Symlink the src_path to the dst_path
Args:
src_path (str): A URI supported by this PathHandler to symlink from
dst_path (str): A URI supported by this PathHandler to symlink to
"""
raise NotImplementedError()
def _set_cwd(self, path: Union[str, None], **kwargs: Any) -> bool:
"""
Set the current working directory. PathHandler classes prepend the cwd
to all URI paths that are handled.
Args:
path (str) or None: A URI supported by this PathHandler. Must be a valid
absolute path or None to set the cwd to None.
Returns:
bool: true if cwd was set without errors
"""
raise NotImplementedError()
def _get_path_with_cwd(self, path: str) -> str:
"""
Default implementation. PathHandler classes that provide a `_set_cwd`
feature should also override this `_get_path_with_cwd` method.
Args:
path (str): A URI supported by this PathHandler.
Returns:
path (str): Full path with the cwd attached.
"""
return path
class NativePathHandler(PathHandler):
"""
Handles paths that can be accessed using Python native system calls. This
handler uses `open()` and `os.*` calls on the given path.
"""
_cwd = None
def _get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
self._check_kwargs(kwargs)
return os.fspath(path)
def _copy_from_local(
self, local_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> None:
self._check_kwargs(kwargs)
local_path = self._get_path_with_cwd(local_path)
dst_path = self._get_path_with_cwd(dst_path)
assert self._copy(src_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs)
def _open(
self,
path: str,
mode: str = "r",
buffering: int = -1,
encoding: Optional[str] = None,
errors: Optional[str] = None,
newline: Optional[str] = None,
closefd: bool = True,
opener: Optional[Callable] = None,
**kwargs: Any,
) -> Union[IO[str], IO[bytes]]:
"""
Open a path.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy works as follows:
* Binary files are buffered in fixed-size chunks; the size of
the buffer is chosen using a heuristic trying to determine the
underlying device’s “block size” and falling back on
io.DEFAULT_BUFFER_SIZE. On many systems, the buffer will
typically be 4096 or 8192 bytes long.
encoding (Optional[str]): the name of the encoding used to decode or
encode the file. This should only be used in text mode.
errors (Optional[str]): an optional string that specifies how encoding
and decoding errors are to be handled. This cannot be used in binary
mode.
newline (Optional[str]): controls how universal newlines mode works
(it only applies to text mode). It can be None, '', '\n', '\r',
and '\r\n'.
closefd (bool): If closefd is False and a file descriptor rather than
a filename was given, the underlying file descriptor will be kept
open when the file is closed. If a filename is given closefd must
be True (the default) otherwise an error will be raised.
opener (Optional[Callable]): A custom opener can be used by passing
a callable as opener. The underlying file descriptor for the file
object is then obtained by calling opener with (file, flags).
opener must return an open file descriptor (passing os.open as opener
results in functionality similar to passing None).
See https://docs.python.org/3/library/functions.html#open for details.
Returns:
file: a file-like object.
"""
self._check_kwargs(kwargs)
return open( # type: ignore
self._get_path_with_cwd(path),
mode,
buffering=buffering,
encoding=encoding,
errors=errors,
newline=newline,
closefd=closefd,
opener=opener,
)
def _copy(self, src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
self._check_kwargs(kwargs)
src_path = self._get_path_with_cwd(src_path)
dst_path = self._get_path_with_cwd(dst_path)
if os.path.exists(dst_path) and not overwrite:
logger = logging.getLogger(__name__)
logger.error("Destination file {} already exists.".format(dst_path))
return False
try:
shutil.copyfile(src_path, dst_path)
return True
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Error in file copy - {}".format(str(e)))
return False
def _mv(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Moves (renames) a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
Returns:
status (bool): True on success
"""
self._check_kwargs(kwargs)
src_path = self._get_path_with_cwd(src_path)
dst_path = self._get_path_with_cwd(dst_path)
if os.path.exists(dst_path):
logger = logging.getLogger(__name__)
logger.error("Destination file {} already exists.".format(dst_path))
return False
try:
shutil.move(src_path, dst_path)
return True
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Error in move operation - {}".format(str(e)))
return False
def _symlink(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Creates a symlink to the src_path at the dst_path
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
Returns:
status (bool): True on success
"""
self._check_kwargs(kwargs)
src_path = self._get_path_with_cwd(src_path)
dst_path = self._get_path_with_cwd(dst_path)
logger = logging.getLogger(__name__)
if not os.path.exists(src_path):
logger.error("Source path {} does not exist".format(src_path))
return False
if os.path.exists(dst_path):
logger.error("Destination path {} already exists.".format(dst_path))
return False
try:
os.symlink(src_path, dst_path)
return True
except Exception as e:
logger.error("Error in symlink - {}".format(str(e)))
return False
def _exists(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.exists(self._get_path_with_cwd(path))
def _isfile(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isfile(self._get_path_with_cwd(path))
def _isdir(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isdir(self._get_path_with_cwd(path))
def _ls(self, path: str, **kwargs: Any) -> List[str]:
self._check_kwargs(kwargs)
return os.listdir(self._get_path_with_cwd(path))
def _mkdirs(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
try:
os.makedirs(path, exist_ok=True)
except OSError as e:
# EEXIST can still happen if multiple processes create the dir concurrently
if e.errno != errno.EEXIST:
raise
def _rm(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
os.remove(path)
def _set_cwd(self, path: Union[str, None], **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
# Remove cwd path if None
if path is None:
self._cwd = None
return True
# Make sure path is a valid Unix path
if not os.path.exists(path):
raise ValueError(f"{path} is not a valid Unix path")
# Make sure path is an absolute path
if not os.path.isabs(path):
raise ValueError(f"{path} is not an absolute path")
self._cwd = path
return True
def _get_path_with_cwd(self, path: str) -> str:
return os.path.normpath(path if not self._cwd else os.path.join(self._cwd, path))
class HTTPURLHandler(PathHandler):
"""
Download URLs and cache them to disk.
"""
def __init__(self) -> None:
self.cache_map: Dict[str, str] = {}
def _get_supported_prefixes(self) -> List[str]:
return ["http://", "https://", "ftp://"]
def _get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
"""
This implementation downloads the remote resource and caches it locally.
The resource will only be downloaded if not previously requested.
"""
self._check_kwargs(kwargs)
if force or path not in self.cache_map or not os.path.exists(self.cache_map[path]):
logger = logging.getLogger(__name__)
parsed_url = urlparse(path)
dirname = os.path.join(get_cache_dir(), os.path.dirname(parsed_url.path.lstrip("/")))
filename = path.split("/")[-1]
cached = os.path.join(dirname, filename)
with file_lock(cached):
if not os.path.isfile(cached):
logger.info("Downloading {} ...".format(path))
cached = download(path, dirname, filename=filename)
logger.info("URL {} cached in {}".format(path, cached))
self.cache_map[path] = cached
return self.cache_map[path]
def _open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a remote HTTP path. The resource is first downloaded and cached
locally.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): Not used for this PathHandler.
Returns:
file: a file-like object.
"""
self._check_kwargs(kwargs)
assert mode in ("r", "rb"), "{} does not support open with {} mode".format(
self.__class__.__name__, mode
)
assert (
buffering == -1
), f"{self.__class__.__name__} does not support the `buffering` argument"
local_path = self._get_local_path(path, force=False)
return open(local_path, mode)
class OneDrivePathHandler(HTTPURLHandler):
"""
Map OneDrive (short) URLs to direct download links
"""
ONE_DRIVE_PREFIX = "https://1drv.ms/u/s!"
def create_one_drive_direct_download(self, one_drive_url: str) -> str:
"""
Converts a short OneDrive URI into a download link that can be used with wget
Args:
one_drive_url (str): A OneDrive URI supported by this PathHandler
Returns:
result_url (str): A direct download URI for the file
"""
data_b64 = base64.b64encode(bytes(one_drive_url, "utf-8"))
data_b64_string = data_b64.decode("utf-8").replace("/", "_").replace("+", "-").rstrip("=")
result_url = f"https://api.onedrive.com/v1.0/shares/u!{data_b64_string}/root/content"
return result_url
def _get_supported_prefixes(self) -> List[str]:
return [self.ONE_DRIVE_PREFIX]
def _get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
"""
This implementation downloads the remote resource and caches it locally.
The resource will only be downloaded if not previously requested.
"""
logger = logging.getLogger(__name__)
direct_url = self.create_one_drive_direct_download(path)
logger.info(f"URL {path} mapped to direct download link {direct_url}")
return super()._get_local_path(os.fspath(direct_url), force=force, **kwargs)
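# Example sketch: the short-link conversion performed by OneDrivePathHandler, on a
# made-up OneDrive share URL.
def _example_one_drive_conversion():
    handler = OneDrivePathHandler()
    direct = handler.create_one_drive_direct_download("https://1drv.ms/u/s!AAAAexample")
    # -> "https://api.onedrive.com/v1.0/shares/u!<base64>/root/content"
    print(direct)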
class PathManagerBase:
"""
A class for users to open generic paths or translate generic paths to file names.
path_manager.method(path) will do the following:
1. Find a handler by checking the prefixes in `self._path_handlers`.
2. Call handler.method(path) on the handler that's found
"""
def __init__(self) -> None:
self._path_handlers: MutableMapping[str, PathHandler] = OrderedDict()
"""
Dict for path prefix to handler
"""
self._native_path_handler: PathHandler = NativePathHandler()
"""
A NativePathHandler that works on posix paths. This is used as the fallback.
"""
self._cwd: Optional[str] = None
"""
Keeps track of the single cwd (if set).
NOTE: Only one PathHandler can have a cwd set at a time.
"""
self._async_handlers: Set[PathHandler] = set()
"""
Keeps track of the PathHandler subclasses where `opena` was used so
all of the threads can be properly joined when calling
`PathManager.join`.
"""
def __get_path_handler(self, path: Union[str, os.PathLike]) -> PathHandler:
"""
Finds a PathHandler that supports the given path. Falls back to the native
PathHandler if no other handler is found.
Args:
path (str or os.PathLike): URI path to resource
Returns:
handler (PathHandler)
"""
path = os.fspath(path) # pyre-ignore
for p in self._path_handlers.keys():
if path.startswith(p):
return self._path_handlers[p]
return self._native_path_handler
def opent(
self, path: str, mode: str = "r", buffering: int = 32, **kwargs: Any
) -> Iterable[Any]:
"""
Open a tabular data source. Only reading is supported.
opent() returns a Python iterable collection object, as opposed to the
bytes/text data returned by open().
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'
buffering (int): number of rows fetched and cached
Returns:
An iterable collection object.
"""
return self.__get_path_handler(path)._opent(path, mode, buffering, **kwargs)
def open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
return self.__get_path_handler(path)._open( # type: ignore
path, mode, buffering=buffering, **kwargs
)
# NOTE: This feature is only implemented for `NativePathHandler` and can
# currently only be used in write mode.
def opena(
self,
path: str,
mode: str = "r",
buffering: int = -1,
callback_after_file_close: Optional[Callable[[None], None]] = None,
**kwargs: Any,
) -> Union[IO[str], IO[bytes]]:
"""
Open a file with asynchronous permissions. `f.write()` calls (and
potentially `f.read()` calls in the future) will be dispatched
asynchronously such that the main program can continue running.
NOTE: Writes to the same path are serialized so they are written in
the same order as they were called but writes to distinct paths can
happen concurrently.
Usage (default / without callback function):
for n in range(50):
results = run_a_large_task(n)
# `f` is a file-like object with asynchronous methods
with path_manager.opena(uri, "w") as f:
f.write(results) # Runs in separate thread
# Main process returns immediately and continues to next iteration
path_manager.async_close()
Usage (advanced / with callback function):
# To asynchronously write to Manifold:
def cb():
path_manager.copy_from_local(
"checkpoint.pt", "manifold://path/to/bucket"
)
f = pm.opena("checkpoint.pt", "wb", callback_after_file_close=cb)
oneflow.save({...}, f)
f.close()
Args:
...
callback_after_file_close (Callable): An optional argument that can
be passed to perform operations that depend on the asynchronous
writes being completed. The file is first written to the local
disk and then the callback is executed.
Returns:
file: a file-like object with asynchronous methods.
"""
non_blocking_io = self.__get_path_handler(path)._opena(
path,
mode,
buffering=buffering,
callback_after_file_close=callback_after_file_close,
**kwargs,
)
# Keep track of the path handlers where `opena` is used so that all of the
# threads can be properly joined on `PathManager.join`.
self._async_handlers.add(self.__get_path_handler(path))
return non_blocking_io
def async_join(self, *paths: str, **kwargs: Any) -> bool:
"""
Ensures that desired async write threads are properly joined.
Usage:
Wait for asynchronous methods operating on specific file paths to
complete.
async_join("path/to/file1.txt")
async_join("path/to/file2.txt", "path/to/file3.txt")
Wait for all asynchronous methods to complete.
async_join()
Args:
*paths (str): Pass in any number of file paths and `async_join` will wait
until all asynchronous activity for those paths is complete. If no
paths are passed in, then `async_join` will wait until all asynchronous
jobs are complete.
Returns:
status (bool): True on success
"""
success = True
if not paths: # Join all.
for handler in self._async_handlers:
success = handler._async_join(**kwargs) and success
else: # Join specific paths.
for path in paths:
success = self.__get_path_handler(path)._async_join(path, **kwargs) and success
return success
def async_close(self, **kwargs: Any) -> bool:
"""
`async_close()` must be called at the very end of any script that uses the
asynchronous `opena` feature. This calls `async_join()` first and then closes
the thread pool used for the asynchronous operations.
Returns:
status (bool): True on success
"""
success = self.async_join(**kwargs)
for handler in self._async_handlers:
success = handler._async_close(**kwargs) and success
self._async_handlers.clear()
return success
def copy(self, src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
# Copying across handlers is not supported.
assert self.__get_path_handler(src_path) == self.__get_path_handler( # type: ignore
dst_path
)
return self.__get_path_handler(src_path)._copy(src_path, dst_path, overwrite, **kwargs)
def mv(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""
Moves (renames) a source path supported by NativePathHandler to
a destination path.
Args:
src_path (str): A URI supported by NativePathHandler
dst_path (str): A URI supported by NativePathHandler
Returns:
status (bool): True on success
Exception:
Raises an AssertionError if the src and dest paths are not
supported by the same path handler.
"""
# Moving across handlers is not supported.
assert self.__get_path_handler(src_path) == self.__get_path_handler( # type: ignore
dst_path
), "Src and dest paths must be supported by the same path handler."
return self.__get_path_handler(src_path)._mv(src_path, dst_path, **kwargs)
def get_local_path(self, path: str, force: bool = False, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk.
Args:
path (str): A URI supported by this PathHandler
force(bool): Forces a download from backend if set to True.
Returns:
local_path (str): a file path which exists on the local file system
"""
path = os.fspath(path)
return self.__get_path_handler(path)._get_local_path( # type: ignore
path, force=force, **kwargs
)
def copy_from_local(
self, local_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> None:
"""
Copies a local file to the specified URI.
If the URI is another local path, this should be functionally identical
to copy.
Args:
local_path (str): a file path which exists on the local file system
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing URI
Returns:
status (bool): True on success
"""
assert os.path.exists(local_path)
return self.__get_path_handler(dst_path)._copy_from_local(
local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs
)
def exists(self, path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
return self.__get_path_handler(path)._exists(path, **kwargs) # type: ignore
def isfile(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
return self.__get_path_handler(path)._isfile(path, **kwargs) # type: ignore
def isdir(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
return self.__get_path_handler(path)._isdir(path, **kwargs) # type: ignore
def ls(self, path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
return self.__get_path_handler(path)._ls(path, **kwargs)
def mkdirs(self, path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
return self.__get_path_handler(path)._mkdirs(path, **kwargs) # type: ignore
def rm(self, path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
return self.__get_path_handler(path)._rm(path, **kwargs) # type: ignore
def symlink(self, src_path: str, dst_path: str, **kwargs: Any) -> bool:
"""Symlink the src_path to the dst_path
Args:
src_path (str): A URI supported by this PathHandler to symlink from
dst_path (str): A URI supported by this PathHandler to symlink to
"""
# Copying across handlers is not supported.
assert self.__get_path_handler(src_path) == self.__get_path_handler( # type: ignore
dst_path
)
return self.__get_path_handler(src_path)._symlink(src_path, dst_path, **kwargs)
def set_cwd(self, path: Union[str, None], **kwargs: Any) -> bool:
"""
Set the current working directory. PathHandler classes prepend the cwd
to all URI paths that are handled.
Args:
path (str) or None: A URI supported by this PathHandler. Must be a valid
absolute Unix path or None to set the cwd to None.
Returns:
bool: true if cwd was set without errors
"""
if path is None and self._cwd is None:
return True
if self.__get_path_handler(path or self._cwd)._set_cwd(path, **kwargs): # type: ignore
self._cwd = path
return True
return False
def register_handler(self, handler: PathHandler, allow_override: bool = False) -> None:
"""
Register a path handler associated with `handler._get_supported_prefixes`
URI prefixes.
Args:
handler (PathHandler)
allow_override (bool): allow overriding existing handler for prefix
"""
logger = logging.getLogger(__name__)
assert isinstance(handler, PathHandler), handler
# Allow override of `NativePathHandler` which is automatically
# instantiated by `PathManager`.
if isinstance(handler, NativePathHandler):
if allow_override:
self._native_path_handler = handler
else:
raise ValueError(
"`NativePathHandler` is registered by default. Use the "
"`allow_override=True` kwarg to override it."
)
return
for prefix in handler._get_supported_prefixes():
if prefix not in self._path_handlers:
self._path_handlers[prefix] = handler
continue
old_handler_type = type(self._path_handlers[prefix])
if allow_override:
# if using the global PathManager, show the warnings
global g_pathmgr
if self == g_pathmgr:
logger.warning(
f"[PathManager] Attempting to register prefix '{prefix}' from "
"the following call stack:\n" + "".join(traceback.format_stack(limit=5))
# show the most recent callstack
)
logger.warning(
f"[PathManager] Prefix '{prefix}' is already registered "
f"by {old_handler_type}. We will override the old handler. "
"To avoid such conflicts, create a project-specific PathManager "
"instead."
)
self._path_handlers[prefix] = handler
else:
raise KeyError(
f"[PathManager] Prefix '{prefix}' already registered by {old_handler_type}!"
)
# Sort path handlers in reverse order so longer prefixes take priority,
# eg: http://foo/bar before http://foo
self._path_handlers = OrderedDict(
sorted(self._path_handlers.items(), key=lambda t: t[0], reverse=True)
)
def set_strict_kwargs_checking(self, enable: bool) -> None:
"""
Toggles strict kwargs checking. If enabled, a ValueError is thrown if any
unused parameters are passed to a PathHandler function. If disabled, only
a warning is given.
With a centralized file API, there is a tradeoff between convenience and
correctness when delegating arguments to the proper I/O layers. An underlying
`PathHandler` may support custom arguments which should not be statically
exposed on the `PathManager` function. For example, a custom `HTTPURLHandler`
may want to expose a `cache_timeout` argument for `open()` which specifies
how old a locally cached resource can be before it's refetched from the
remote server. This argument would not make sense for a `NativePathHandler`.
If strict kwargs checking is disabled, `cache_timeout` can be passed to
`PathManager.open` which will forward the arguments to the underlying
handler. By default, checking is enabled since it is innately unsafe:
multiple `PathHandler`s could reuse arguments with different semantic
meanings or types.
Args:
enable (bool)
"""
self._native_path_handler._strict_kwargs_check = enable
for handler in self._path_handlers.values():
handler._strict_kwargs_check = enable
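# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# A PathManagerBase dispatches on URI prefix: registered handlers are matched
# longest-prefix-first, and anything unmatched falls back to the NativePathHandler.
# The paths/URLs below are hypothetical placeholders.
def _example_path_manager_usage() -> None:
    pm = PathManagerBase()
    pm.register_handler(HTTPURLHandler())
    # A plain local path is not matched by any registered prefix, so it is served
    # by the fallback NativePathHandler.
    if pm.exists("/tmp/example.txt"):
        with pm.open("/tmp/example.txt", "r") as f:
            _ = f.read()
    # An http(s) URL is routed to HTTPURLHandler, which downloads and caches it
    # locally before returning a usable file path.
    # local_copy = pm.get_local_path("https://example.com/checkpoint.bin")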
class PathManagerFactory:
"""
PathManagerFactory is the class responsible for creating new PathManager
instances and removing them when no longer needed.
PathManager can be instantiated directly too, but it is recommended that
you use PathManagerFactory to create them.
"""
GLOBAL_PATH_MANAGER = "global_path_manager"
pm_list = {}
@staticmethod
def get(key=GLOBAL_PATH_MANAGER) -> PathManagerBase:
"""
Get the path manager instance associated with a key.
A new instance will be created if there is no existing
instance associated with the key passed in.
Args:
key (str):
"""
if key not in PathManagerFactory.pm_list:
PathManagerFactory.pm_list[key] = PathManagerBase()
return PathManagerFactory.pm_list[key]
@staticmethod
def remove(key):
"""
Remove the path manager instance associated with a key.
Args:
key (str):
"""
if key in PathManagerFactory.pm_list:
_pm = PathManagerFactory.pm_list.pop(key) # noqa
del _pm
"""
A global instance of PathManager.
This global instance is provided for backward compatibility, but it is
recommended that clients use PathManagerFactory
"""
g_pathmgr = PathManagerFactory.get()
PathManager = PathManagerBase()
PathManager.register_handler(HTTPURLHandler())
PathManager.register_handler(OneDrivePathHandler())
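# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# PathManagerFactory caches one PathManagerBase per key: repeated get() calls with
# the same key return the same instance, and remove() drops it. The key name below
# is a hypothetical example.
def _example_path_manager_factory() -> None:
    pm_a = PathManagerFactory.get("my_project")  # creates a new instance for this key
    pm_b = PathManagerFactory.get("my_project")  # returns the cached instance
    assert pm_a is pm_b
    PathManagerFactory.remove("my_project")      # release it when no longer needed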
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import fnmatch
import hashlib
import json
import logging
import os
import shutil
import sys
import tempfile
from functools import wraps
from io import open
from pathlib import Path
import boto3
import requests
import wget
from botocore.config import Config
from botocore.exceptions import ClientError
from tqdm import tqdm
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
cache_home = Path(os.getenv("OF_CACHE_ROOT", Path.home() / ".of_cache"))
default_cache_path = str(cache_home / "libai")
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
If the url ends with .h5 (Keras HDF5 weights), append '.h5' to the name
so that TF 2.0 can identify it as an HDF5 file
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3
/tensorflow/python/keras/engine/network.py#L1380)
"""
url_bytes = url.encode("utf-8")
url_hash = hashlib.sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode("utf-8")
etag_hash = hashlib.sha256(etag_bytes)
filename += "." + etag_hash.hexdigest()
if url.endswith(".h5"):
filename += ".h5"
return filename
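# --- Illustrative sketch (editor's addition, not part of the original module) ---
# The cache filename is sha256(url), optionally followed by "." + sha256(etag),
# with a trailing ".h5" preserved for Keras HDF5 weights. The URL and etag below
# are hypothetical placeholders.
def _example_url_to_filename() -> str:
    name = url_to_filename("https://example.com/weights.h5", etag='"abc123"')
    # name == sha256(url).hexdigest() + "." + sha256(etag).hexdigest() + ".h5"
    return name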
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = default_cache_path
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + ".json"
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata["url"]
etag = metadata["etag"]
return url, etag
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
Args:
cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
force_download: if True, re-download the file even if it's already cached in the cache dir.
"""
if cache_dir is None:
cache_dir = default_cache_path
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
if parsed.scheme in ("http", "https", "s3"):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(
url_or_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == "":
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise EnvironmentError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url, proxies=None):
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url, temp_file, proxies=None):
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file, proxies=None):
req = requests.get(url, stream=True, proxies=proxies)
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = default_cache_path
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url, proxies=proxies)
else:
try:
response = requests.head(
url, allow_redirects=True, proxies=proxies, timeout=etag_timeout
)
if response.status_code != 200:
etag = None
else:
etag = response.headers.get("ETag")
except (EnvironmentError, requests.exceptions.Timeout):
etag = None
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*")
matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files))
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if not os.path.exists(cache_path) or force_download:
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info(
"%s not found in cache or force_download set to True, downloading to %s",
url,
temp_file.name,
)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file, proxies=proxies)
else:
http_get(url, temp_file, proxies=proxies)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, "wb") as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
output_string = json.dumps(meta)
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path
def get_md5(fname):
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
result = hash_md5.hexdigest()
return result
def download_file(out_path: str, url):
logger.info(f"downloading from {url} to {out_path}")
wget.download(url, out=out_path)
def get_data_from_cache(url, cache_dir=None, force_download=False, md5=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = default_cache_path
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
filename = url.split("/")[-1]
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If the file is already in the cache, just verify its md5 when one is provided
if os.path.exists(cache_path) and md5 is not None:
local_file_md5 = get_md5(cache_path)
if local_file_md5 != md5:
os.unlink(cache_path)
download_file(cache_path, url)
# If the file is not in the cache yet, download it with wget
if not os.path.exists(cache_path):
download_file(cache_path, url)
if not os.path.exists(cache_path) or force_download:
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info(
"%s not found in cache or force_download set to True, downloading to %s",
url,
temp_file.name,
)
# GET file object
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, "wb") as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
output_string = json.dumps(meta)
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path
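# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# get_data_from_cache() keys the cache on the URL's basename and, when an md5 is
# provided, re-downloads the file if the local copy's checksum does not match.
# The URL and md5 below are hypothetical placeholders.
def _example_get_data_from_cache() -> str:
    return get_data_from_cache(
        "https://example.com/dataset/train.bin",
        md5="0123456789abcdef0123456789abcdef",
    )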
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import numpy as np
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/history_buffer.py
# --------------------------------------------------------
class HistoryBuffer:
"""
Track a series of scalar values and provide access to smoothed values over a
window or the global average of the series.
"""
def __init__(self, max_length: int = 1000000):
"""
Args:
max_length: maximal number of values that can be stored in the
buffer. When the capacity of the buffer is exhausted, old
values will be removed.
"""
self._max_length: int = max_length
self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
self._count: int = 0
self._global_avg: float = 0
def update(self, value: float, iteration: Optional[float] = None):
"""
Add a new scalar value produced at a certain iteration. If the length
of the buffer exceeds self._max_length, the oldest element will be
removed from the buffer.
"""
if iteration is None:
iteration = self._count
if len(self._data) == self._max_length:
self._data.pop(0)
self._data.append((value, iteration))
self._count += 1
self._global_avg += (value - self._global_avg) / self._count
def latest(self):
"""
Return the latest scalar value added to the buffer.
"""
return self._data[-1][0]
def median(self, window_size: int):
"""
Return the median of the latest `window_size` values in the buffer.
"""
return np.median([x[0] for x in self._data[-window_size:]])
def avg(self, window_size: int):
"""
Return the mean of the latest `window_size` values in the buffer.
"""
return np.mean([x[0] for x in self._data[-window_size:]])
def global_avg(self):
"""
Return the mean of all the elements in the buffer. Note that this
includes those getting removed due to limited buffer storage.
"""
return self._global_avg
def values(self):
"""
Returns:
list[(number, iteration)]: content of the current buffer.
"""
return self._data
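# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# A HistoryBuffer keeps at most `max_length` (value, iteration) pairs, while the
# global average keeps tracking every value ever seen, including evicted ones.
def _example_history_buffer() -> None:
    buf = HistoryBuffer(max_length=3)
    for it, loss in enumerate([0.9, 0.7, 0.6, 0.5]):
        buf.update(loss, iteration=it)
    assert len(buf.values()) == 3  # the oldest entry (0.9) was evicted
    assert buf.latest() == 0.5
    assert abs(buf.global_avg() - (0.9 + 0.7 + 0.6 + 0.5) / 4) < 1e-8
    _ = buf.median(window_size=2)  # median of the last two values (0.55)
    _ = buf.avg(window_size=2)     # mean of the last two values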