Unverified commit 08f534d2, authored by Sylvain Gugger and committed by GitHub

Doc styling (#8067)

* Important files

* Styling them all

* Revert "Styling them all"

This reverts commit 7d029395fdae8513b8281cbc2a6c239f8093503e.

* Styling them for realsies

* Fix syntax error

* Fix benchmark_utils

* More fixes

* Fix modeling auto and script

* Remove new line

* Fixes

* More fixes

* Fix more files

* Style

* Add FSMT

* More fixes

* More fixes

* More fixes

* More fixes

* Fixes

* More fixes

* More fixes

* Last fixes

* Make sphinx happy
parent 04a17f85
......@@ -62,8 +62,8 @@ SEG_ID_PAD = 4
class XLNetTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "fast" XLNet tokenizer (backed by HuggingFace's `tokenizers` library). Based on
`SentencePiece <https://github.com/google/sentencepiece>`__.
Construct a "fast" XLNet tokenizer (backed by HuggingFace's `tokenizers` library). Based on `SentencePiece
<https://github.com/google/sentencepiece>`__.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
methods. Users should refer to this superclass for more information regarding those methods.
......@@ -83,28 +83,27 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
.. note::
When building a sequence using special tokens, this is not the token that is used for the beginning
of sequence. The token used is the :obj:`cls_token`.
When building a sequence using special tokens, this is not the token that is used for the beginning of
sequence. The token used is the :obj:`cls_token`.
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
The end of sequence token.
.. note::
When building a sequence using special tokens, this is not the token that is used for the end
of sequence. The token used is the :obj:`sep_token`.
When building a sequence using special tokens, this is not the token that is used for the end of
sequence. The token used is the :obj:`sep_token`.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
sep_token (:obj:`str`, `optional`, defaults to :obj:`"<sep>"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
for sequence classification or for a text and a question for question answering.
It is also used as the last token of a sequence built with special tokens.
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
cls_token (:obj:`str`, `optional`, defaults to :obj:`"<cls>"`):
The classifier token which is used when doing sequence classification (classification of the whole
sequence instead of per-token classification). It is the first token of the sequence when built with
special tokens.
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
mask_token (:obj:`str`, `optional`, defaults to :obj:`"<mask>"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
......@@ -166,9 +165,8 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks
by concatenating and adding special tokens.
An XLNet sequence has the following format:
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. An XLNet sequence has the following format:
- single sequence: ``X <sep> <cls>``
- pair of sequences: ``A <sep> B <sep> <cls>``
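As a hedged illustration of the two formats above (the checkpoint name and example sentences are placeholders, not taken from this diff):

from transformers import XLNetTokenizerFast

tokenizer = XLNetTokenizerFast.from_pretrained("xlnet-base-cased")
ids_a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world"))
ids_b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("How are you?"))

# Single sequence: X <sep> <cls>
print(tokenizer.build_inputs_with_special_tokens(ids_a))
# Pair of sequences: A <sep> B <sep> <cls>
print(tokenizer.build_inputs_with_special_tokens(ids_a, ids_b))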
......@@ -223,8 +221,8 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task.
An XLNet sequence pair mask has the following format:
Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLNet
sequence pair mask has the following format:
::
......
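Continuing the hedged tokenizer snippet above (reusing `tokenizer`, `ids_a` and `ids_b`), the token type ids for a pair can be inspected directly; for XLNet they are 0 for sequence A and its `<sep>`, 1 for sequence B and its `<sep>`, and 2 for the final `<cls>`:

token_type_ids = tokenizer.create_token_type_ids_from_sequences(ids_a, ids_b)
print(token_type_ids)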
......@@ -156,8 +156,7 @@ logger = logging.get_logger(__name__)
class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch,
optimized for 🤗 Transformers.
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
Args:
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
......@@ -169,18 +168,19 @@ class Trainer:
provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
they work the same way as the 🤗 Transformers models.
args (:class:`~transformers.TrainingArguments`, `optional`):
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
The arguments to tweak for training. Will default to a basic instance of
:class:`~transformers.TrainingArguments` with the ``output_dir`` set to a directory named `tmp_trainer` in
the current directory if not provided.
data_collator (:obj:`DataCollator`, `optional`):
The function to use to form a batch from a list of elements of :obj:`train_dataset` or
:obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
:func:`~transformers.DataCollatorWithPadding` otherwise.
train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
``model.forward()`` method are automatically removed.
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
......@@ -189,22 +189,20 @@ class Trainer:
A function that instantiates the model to be used. If provided, each call to
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
The function may have zero argument, or a single one containing the optuna/Ray Tune trial object, to be
able to choose different architectures according to hyper parameters (such as layer count, sizes of inner
layers, dropout probabilities, etc.).
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to metric values.
callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in :doc:`here <callback>`.
If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of
:class:`~transformers.AdamW` on your model and a scheduler given by
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
kwargs:
Deprecated keyword arguments.
"""
def __init__(
......@@ -395,8 +393,8 @@ class Trainer:
"""
Returns the training :class:`~torch.utils.data.DataLoader`.
Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler
(adapted to distributed training if necessary) otherwise.
Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted
to distributed training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
"""
......@@ -985,8 +983,10 @@ class Trainer:
Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
more information see:
- the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
- the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__
- the documentation of `optuna.create_study
<https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
- the documentation of `tune.run
<https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__
Returns:
:class:`transformers.trainer_utils.BestRun`: All the information about the best run.
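A hedged sketch of driving this API with the optuna backend, reusing the placeholder model and toy dataset names from the Trainer illustration earlier (the search space and trial count are arbitrary):

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

def model_init():
    # A fresh model for every trial, as required by hyperparameter search.
    return AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

def hp_space(trial):
    # `trial` is an optuna Trial; sample whatever should be tuned.
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True)}

search_trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(output_dir="tmp_trainer", evaluation_strategy="epoch"),
    train_dataset=train_dataset,  # the toy dataset from the earlier sketch
    eval_dataset=train_dataset,   # toy example only
)
best_run = search_trainer.hyperparameter_search(hp_space=hp_space, n_trials=5, direction="minimize")
print(best_run.hyperparameters)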
......@@ -1124,8 +1124,8 @@ class Trainer:
def is_local_master(self) -> bool:
"""
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
several machines) main process.
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
machines) main process.
.. warning::
......@@ -1136,8 +1136,8 @@ class Trainer:
def is_local_process_zero(self) -> bool:
"""
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
several machines) main process.
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
machines) main process.
"""
if is_torch_tpu_available():
return xm.is_master_ordinal(local=True)
......@@ -1146,8 +1146,8 @@ class Trainer:
def is_world_master(self) -> bool:
"""
Whether or not this process is the global main process (when training in a distributed fashion on
several machines, this is only going to be :obj:`True` for one process).
Whether or not this process is the global main process (when training in a distributed fashion on several
machines, this is only going to be :obj:`True` for one process).
.. warning::
......@@ -1158,8 +1158,8 @@ class Trainer:
def is_world_process_zero(self) -> bool:
"""
Whether or not this process is the global main process (when training in a distributed fashion on
several machines, this is only going to be :obj:`True` for one process).
Whether or not this process is the global main process (when training in a distributed fashion on several
machines, this is only going to be :obj:`True` for one process).
"""
if is_torch_tpu_available():
return xm.is_master_ordinal(local=False)
......@@ -1267,16 +1267,16 @@ class Trainer:
"""
Run evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are
task-dependent (pass it to the init :obj:`compute_metrics` argument).
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
(pass it to the init :obj:`compute_metrics` argument).
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
the :obj:`__len__` method.
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
:obj:`__len__` method.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
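For reference, a minimal hedged sketch of the :obj:`compute_metrics` contract mentioned above (the accuracy computation is a placeholder):

import numpy as np
from transformers import EvalPrediction

def compute_metrics(eval_pred: EvalPrediction):
    preds = np.argmax(eval_pred.predictions, axis=-1)
    return {"accuracy": float((preds == eval_pred.label_ids).mean())}

# Passed at init, e.g. Trainer(..., compute_metrics=compute_metrics), then:
# metrics = trainer.evaluate()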
......@@ -1301,22 +1301,20 @@ class Trainer:
"""
Run prediction and returns predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels.
In that case, this method will also return metrics, like in :obj:`evaluate()`.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
will also return metrics, like in :obj:`evaluate()`.
Args:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
Returns:
`NamedTuple`:
predictions (:obj:`np.ndarray`):
The predictions on :obj:`test_dataset`.
label_ids (:obj:`np.ndarray`, `optional`):
The labels (if the dataset contained some).
metrics (:obj:`Dict[str, float]`, `optional`):
The potential dictionary of metrics (if the dataset contained labels).
Returns: `NamedTuple`: A namedtuple with the following keys:
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
- label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
contained labels).
"""
if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
raise ValueError("test_dataset must implement __len__")
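A hedged usage sketch for the return value described above (`trainer` and `test_dataset` stand in for objects built as in the earlier illustrations):

output = trainer.predict(test_dataset)
print(output.predictions.shape)   # raw predictions on test_dataset
if output.label_ids is not None:  # only set when the dataset carries labels
    print(output.metrics)         # e.g. {"eval_loss": ...}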
......@@ -1465,8 +1463,8 @@ class Trainer:
Whether or not to return the loss only.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
A tuple with the loss, logits and labels (each being optional).
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
labels (each being optional).
"""
has_labels = all(inputs.get(k) is not None for k in self.label_names)
inputs = self._prepare_inputs(inputs)
......@@ -1507,9 +1505,9 @@ class Trainer:
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
"""
For models that inherit from :class:`~transformers.PreTrainedModel`, uses
that method to compute the number of floating point operations for every backward + forward pass. If using
another model, either implement such a method in the model or subclass and override this method.
For models that inherit from :class:`~transformers.PreTrainedModel`, uses that method to compute the number of
floating point operations for every backward + forward pass. If using another model, either implement such a
method in the model or subclass and override this method.
Args:
model (:obj:`nn.Module`):
......
......@@ -64,11 +64,11 @@ class TrainerState:
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
several machines) main process.
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not this process is the global main process (when training in a distributed fashion on
several machines, this is only going to be :obj:`True` for one process).
Whether or not this process is the global main process (when training in a distributed fashion on several
machines, this is only going to be :obj:`True` for one process).
is_hyper_param_search (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search.
This will impact the way data will be logged in TensorBoard.
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
impact the way data will be logged in TensorBoard.
"""
epoch: Optional[float] = None
......
......@@ -135,14 +135,12 @@ def torch_distributed_zero_first(local_rank: int):
class SequentialDistributedSampler(Sampler):
"""
Distributed Sampler that subsamples indices sequentially,
making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training),
which means that the model params won't have to be synced (i.e. will not hang
for synchronization even if varied number of forward passes), we still add extra
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training), which means that the model params won't
have to be synced (i.e. will not hang for synchronization even if varied number of forward passes), we still add
extra samples to the sampler to make it evenly divisible (like in `DistributedSampler`) to make it easy to `gather`
or `reduce` resulting tensors at the end of the loop.
"""
def __init__(self, dataset, num_replicas=None, rank=None):
......@@ -203,16 +201,15 @@ def nested_truncate(tensors, limit):
class DistributedTensorGatherer:
"""
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU
by chunks.
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on
CPU at every step, our sampler will generate the following indices:
If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on CPU at every
step, our sampler will generate the following indices:
:obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`
to get something of size a multiple of 3 (so that each process gets the same dataset length). Then
process 0, 1 and 2 will be responsible for making predictions for the following samples:
to get something of size a multiple of 3 (so that each process gets the same dataset length). Then process 0, 1 and
2 will be responsible for making predictions for the following samples:
- P0: :obj:`[0, 1, 2, 3, 4, 5]`
- P1: :obj:`[6, 7, 8, 9, 10, 11]`
......@@ -224,13 +221,13 @@ class DistributedTensorGatherer:
- P1: :obj:`[6, 7]`
- P2: :obj:`[12, 13]`
So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor)
corresponding to the following indices:
So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor) corresponding to
the following indices:
:obj:`[0, 1, 6, 7, 12, 13]`
If we directly concatenate our results without taking any precautions, the user will then get
the predictions for the indices in this order at the end of the prediction loop:
If we directly concatenate our results without taking any precautions, the user will then get the predictions for
the indices in this order at the end of the prediction loop:
:obj:`[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`
......
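A hedged sketch of how such a gatherer is driven; the import path and the toy shapes are assumptions, not taken from this diff:

import numpy as np
from transformers.trainer_pt_utils import DistributedTensorGatherer  # path assumed for this version

world_size, num_samples = 3, 16
gatherer = DistributedTensorGatherer(world_size, num_samples)
# Pretend we ran 3 prediction steps; each step yields logits for batch_size * world_size = 6 samples.
for _ in range(3):
    gatherer.add_arrays(np.random.rand(6, 5))
predictions = gatherer.finalize()  # interleaved back into dataset order and truncated to num_samples
print(predictions.shape)           # (16, 5)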
......@@ -30,8 +30,7 @@ logger = logging.get_logger(__name__)
class TFTrainer:
"""
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
optimized for 🤗 Transformers.
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow, optimized for 🤗 Transformers.
Args:
model (:class:`~transformers.TFPreTrainedModel`):
......@@ -40,15 +39,15 @@ class TFTrainer:
The arguments to tweak training.
train_dataset (:class:`~tf.data.Dataset`, `optional`):
The dataset to use for training. The dataset should yield tuples of ``(features, labels)`` where
``features`` is a dict of input features and ``labels`` is the labels. If ``labels`` is a tensor, the loss is
calculated by the model by calling ``model(features, labels=labels)``. If ``labels`` is a dict, such as when
using a QuestionAnswering head model with multiple targets, the loss is instead calculated by calling
``features`` is a dict of input features and ``labels`` is the labels. If ``labels`` is a tensor, the loss
is calculated by the model by calling ``model(features, labels=labels)``. If ``labels`` is a dict, such as
when using a QuestionAnswering head model with multiple targets, the loss is instead calculated by calling
``model(features, **labels)``.
eval_dataset (:class:`~tf.data.Dataset`, `optional`):
The dataset to use for evaluation. The dataset should yield tuples of ``(features, labels)`` where
``features`` is a dict of input features and ``labels`` is the labels. If ``labels`` is a tensor, the loss is
calculated by the model by calling ``model(features, labels=labels)``. If ``labels`` is a dict, such as when
using a QuestionAnswering head model with multiple targets, the loss is instead calculated by calling
``features`` is a dict of input features and ``labels`` is the labels. If ``labels`` is a tensor, the loss
is calculated by the model by calling ``model(features, labels=labels)``. If ``labels`` is a dict, such as
when using a QuestionAnswering head model with multiple targets, the loss is instead calculated by calling
``model(features, **labels)``.
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a
......@@ -59,8 +58,8 @@ class TFTrainer:
A tuple containing the optimizer and the scheduler to use. The optimizer default to an instance of
:class:`tf.keras.optimizers.Adam` if :obj:`args.weight_decay_rate` is 0 else an instance of
:class:`~transformers.AdamWeightDecay`. The scheduler will default to an instance of
:class:`tf.keras.optimizers.schedules.PolynomialDecay` if :obj:`args.num_warmup_steps` is 0 else
an instance of :class:`~transformers.WarmUp`.
:class:`tf.keras.optimizers.schedules.PolynomialDecay` if :obj:`args.num_warmup_steps` is 0 else an
instance of :class:`~transformers.WarmUp`.
kwargs:
Deprecated keyword arguments.
"""
......@@ -155,10 +154,10 @@ class TFTrainer:
Args:
eval_dataset (:class:`~tf.data.Dataset`, `optional`):
If provided, will override `self.eval_dataset`. The dataset should yield tuples of ``(features,
labels)`` where ``features`` is a dict of input features and ``labels`` is the labels. If ``labels``
is a tensor, the loss is calculated by the model by calling ``model(features, labels=labels)``. If
``labels`` is a dict, such as when using a QuestionAnswering head model with multiple targets, the
loss is instead calculated by calling ``model(features, **labels)``.
labels)`` where ``features`` is a dict of input features and ``labels`` is the labels. If ``labels`` is
a tensor, the loss is calculated by the model by calling ``model(features, labels=labels)``. If
``labels`` is a dict, such as when using a QuestionAnswering head model with multiple targets, the loss
is instead calculated by calling ``model(features, **labels)``.
Subclass and override this method if you want to inject some custom behavior.
"""
......@@ -187,11 +186,11 @@ class TFTrainer:
Args:
test_dataset (:class:`~tf.data.Dataset`):
The dataset to use. The dataset should yield tuples of ``(features, labels)`` where ``features`` is
a dict of input features and ``labels`` is the labels. If ``labels`` is a tensor, the loss is
calculated by the model by calling ``model(features, labels=labels)``. If ``labels`` is a dict, such
as when using a QuestionAnswering head model with multiple targets, the loss is instead calculated
by calling ``model(features, **labels)``.
The dataset to use. The dataset should yield tuples of ``(features, labels)`` where ``features`` is a
dict of input features and ``labels`` is the labels. If ``labels`` is a tensor, the loss is calculated
by the model by calling ``model(features, labels=labels)``. If ``labels`` is a dict, such as when using
a QuestionAnswering head model with multiple targets, the loss is instead calculated by calling
``model(features, **labels)``.
Subclass and override this method if you want to inject some custom behavior.
"""
......@@ -234,14 +233,15 @@ class TFTrainer:
"""
Setup the optional Weights & Biases (`wandb`) integration.
One can subclass and override this method to customize the setup if needed. Find more information
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
One can subclass and override this method to customize the setup if needed. Find more information `here
<https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
Environment:
WANDB_PROJECT:
(Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
(Optional): str - "huggingface" by default, set this to a custom string to store results in a different
project.
WANDB_DISABLED:
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely.
"""
if hasattr(self, "_setup_wandb"):
warnings.warn(
......@@ -266,8 +266,8 @@ class TFTrainer:
COMET_OFFLINE_DIRECTORY:
(Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
For a number of configurable items in the environment,
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
For a number of configurable items in the environment, see `here
<https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
"""
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
......@@ -419,14 +419,14 @@ class TFTrainer:
"""
Run evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are
task-dependent (pass it to the init :obj:`compute_metrics` argument).
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
(pass it to the init :obj:`compute_metrics` argument).
Args:
eval_dataset (:class:`~tf.data.Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`. The dataset should yield tuples of
``(features, labels)`` where ``features`` is a dict of input features and ``labels`` is the labels.
If ``labels`` is a tensor, the loss is calculated by the model by calling ``model(features,
``(features, labels)`` where ``features`` is a dict of input features and ``labels`` is the labels. If
``labels`` is a tensor, the loss is calculated by the model by calling ``model(features,
labels=labels)``. If ``labels`` is a dict, such as when using a QuestionAnswering head model with
multiple targets, the loss is instead calculated by calling ``model(features, **labels)``.
......@@ -753,24 +753,23 @@ class TFTrainer:
"""
Run prediction and returns predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels.
In that case, this method will also return metrics, like in :obj:`evaluate()`.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
will also return metrics, like in :obj:`evaluate()`.
Args:
test_dataset (:class:`~tf.data.Dataset`):
Dataset to run the predictions on. The dataset should yield tuples of ``(features, labels)`` where
``features`` is a dict of input features and ``labels`` is the labels. If ``labels`` is a tensor,
the loss is calculated by the model by calling ``model(features, labels=labels)``. If ``labels`` is
a dict, such as when using a QuestionAnswering head model with multiple targets, the loss is instead
calculated by calling ``model(features, **labels)``.
Returns:
`NamedTuple`:
predictions (:obj:`np.ndarray`):
The predictions on :obj:`test_dataset`.
label_ids (:obj:`np.ndarray`, `optional`):
The labels (if the dataset contained some).
metrics (:obj:`Dict[str, float]`, `optional`):
The potential dictionary of metrics (if the dataset contained labels).
``features`` is a dict of input features and ``labels`` is the labels. If ``labels`` is a tensor, the
loss is calculated by the model by calling ``model(features, labels=labels)``. If ``labels`` is a dict,
such as when using a QuestionAnswering head model with multiple targets, the loss is instead calculated
by calling ``model(features, **labels)``.
Returns: `NamedTuple`: A namedtuple with the following keys:
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
- label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
contained labels).
"""
test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)
......
......@@ -28,8 +28,8 @@ from .tokenization_utils_base import ExplicitEnum
def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
(if installed).
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
installed).
Args:
seed (:obj:`int`): The seed to set.
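Typical usage of this helper, for reference:

from transformers import set_seed

set_seed(42)  # seeds Python's `random`, numpy, torch and (if installed) TensorFlow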
......
......@@ -35,11 +35,11 @@ def default_logdir() -> str:
@dataclass
class TrainingArguments:
"""
TrainingArguments is the subset of the arguments we use in our example scripts
**which relate to the training loop itself**.
TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
itself**.
Using :class:`~transformers.HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on the command line.
Using :class:`~transformers.HfArgumentParser` we can turn this class into argparse arguments to be able to specify
them on the command line.
Parameters:
output_dir (:obj:`str`):
......@@ -128,7 +128,8 @@ class TrainingArguments:
Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the
same value as :obj:`logging_steps` if not set.
dataloader_num_workers (:obj:`int`, `optional`, defaults to 0):
Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process.
Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the
main process.
past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
......@@ -143,15 +144,14 @@ class TrainingArguments:
If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model
forward method.
(Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
label_names (:obj:`List[str]`, `optional`):
The list of keys in your dictionary of inputs that correspond to the labels.
Will eventually default to :obj:`["labels"]` except if the model used is one of the
:obj:`XxxForQuestionAnswering` in which case it will default to :obj:`["start_positions", "end_positions"]`.
load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to load the best model found during training at the end of training.
.. note::
......@@ -164,10 +164,9 @@ class TrainingArguments:
loss).
If you set this value, :obj:`greater_is_better` will default to :obj:`True`. Don't forget to set it to
:obj:`False` if your metric is better when lower.
greater_is_better (:obj:`bool`, `optional`):
Use in conjunction with :obj:`load_best_model_at_end` and :obj:`metric_for_best_model` to specify if better
models should have a greater metric or not. Will default to:
- :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or
:obj:`"eval_loss"`.
......
......@@ -16,11 +16,11 @@ if is_tf_available():
@dataclass
class TFTrainingArguments(TrainingArguments):
"""
TrainingArguments is the subset of the arguments we use in our example scripts
**which relate to the training loop itself**.
TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
itself**.
Using :class:`~transformers.HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on the command line.
Using :class:`~transformers.HfArgumentParser` we can turn this class into argparse arguments to be able to specify
them on the command line.
Parameters:
output_dir (:obj:`str`):
......
......@@ -44,8 +44,8 @@ _default_log_level = logging.WARNING
def _get_default_logging_level():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level.
If it is not - fall back to ``_default_log_level``
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to ``_default_log_level``
"""
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
if env_level_str:
......@@ -194,8 +194,8 @@ def enable_default_handler() -> None:
def disable_propagation() -> None:
"""Disable propagation of the library log outputs.
Note that log propagation is disabled by default.
"""
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
"""
_configure_library_root_logger()
......@@ -203,9 +203,9 @@ def disable_propagation() -> None:
def enable_propagation() -> None:
"""Enable propagation of the library log outputs.
Please disable the HuggingFace Transformers's default handler to prevent double logging if the root logger has
been configured.
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
prevent double logging if the root logger has been configured.
"""
_configure_library_root_logger()
......
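A hedged sketch of the verbosity and propagation controls discussed above (the environment variable must be set before `transformers` is imported):

import os
os.environ["TRANSFORMERS_VERBOSITY"] = "info"  # picked up by _get_default_logging_level()

from transformers.utils import logging

logging.set_verbosity_info()
logging.enable_propagation()         # forward library records to the root logger
logging.disable_default_handler()    # avoid double logging once propagation is enabled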
......@@ -30,7 +30,7 @@ def format_time(t):
def html_progress_bar(value, total, prefix, label, width=300):
"Html code for a progress bar `value`/`total` with `label` on the right, `prefix` on the left."
# docstyle-ignore
return f"""
<div>
<style>
......@@ -71,11 +71,12 @@ class NotebookProgressBar:
A progress bar for display in a notebook.
Class attributes (overridden by derived classes)
- **warmup** (:obj:`int`) -- The number of iterations to do at the beginning while ignoring
:obj:`update_every`.
- **update_every** (:obj:`float`) -- Since calling the time takes some time, we only do it
every presumed :obj:`update_every` seconds. The progress bar uses the average time passed
up until now to guess the next value for which it will call the update.
- **update_every** (:obj:`float`) -- Since calling the time takes some time, we only do it every presumed
:obj:`update_every` seconds. The progress bar uses the average time passed up until now to guess the next
value for which it will call the update.
Args:
total (:obj:`int`):
......@@ -245,8 +246,8 @@ class NotebookTrainingTracker(NotebookProgressBar):
def add_child(self, total, prefix=None, width=300):
"""
Add a child progress bar displayed under the table of metrics. The child progress bar is returned (so it can
be easily updated).
Add a child progress bar displayed under the table of metrics. The child progress bar is returned (so it can be
easily updated).
Args:
total (:obj:`int`): The number of iterations for the child progress bar.
......
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Style utils for the .rst and the docstrings."""
import argparse
import os
import re
import warnings
from enum import Enum
# Special blocks where the inside should be formatted.
TEXTUAL_BLOCKS = ["note", "warning"]
# List of acceptable characters for titles and sections underline.
TITLE_SPECIAL_CHARS = """= - ` : ' " ~ ^ _ * + # < >""".split(" ")
# Special words for docstrings (s? means the s is optional)
DOC_SPECIAL_WORD = [
"Args?",
"Params?",
"Parameters?",
"Arguments?",
"Examples?",
"Usage",
"Returns?",
"Raises?",
"Attributes?",
]
# Regexes
# Matches any declaration of textual block, like `.. note::`. (ignore case to avoid writing all versions in the list)
_re_textual_blocks = re.compile(r"^\s*\.\.\s+(" + "|".join(TEXTUAL_BLOCKS) + r")\s*::\s*$", re.IGNORECASE)
# Matches list introduction in rst.
_re_list = re.compile(r"^(\s*-\s+|\s*\*\s+|\s*\d+.\s+)")
# Matches the indent in a line.
_re_indent = re.compile(r"^(\s*)\S")
# Matches a table declaration in rst.
_re_table = re.compile(r"(\+-+)+\+\s*$")
# Matches a code block in rst `:: `.
_re_code_block = re.compile(r"^\s*::\s*$")
# Matches any block of the form `.. something::` or `.. something:: bla`.
_re_ignore = re.compile(r"^\s*\.\.\s+(\S+)\s*::\s*\S*\s*$")
# Matches comment introduction in rst.
_re_comment = re.compile(r"\s*\.\.\s*$")
# Matches the special tag to ignore some paragraphs.
_re_doc_ignore = re.compile(r"(\.\.|#)\s*docstyle-ignore")
# Matches the example introduction in docstrings.
_re_example = re.compile(r"::\s*$")
# Matches the parameters introduction in docstrings.
_re_arg_def = re.compile(r"^\s*(Args?|Parameters?|Params|Arguments?|Environment|Attributes?)\s*:\s*$")
# Matches the return introduction in docstrings.
_re_return = re.compile(r"^\s*(Returns?|Raises?|Note)\s*:\s*$")
# Matches any doc special word without an empty line before.
_re_any_doc_special_word = re.compile(r"[^\n]\n([ \t]*)(" + "|".join(DOC_SPECIAL_WORD) + r")(::?\s*)\n")
class SpecialBlock(Enum):
NOT_SPECIAL = 0
NO_STYLE = 1
ARG_LIST = 2
def split_text_in_lines(text, max_len, prefix="", min_indent=None):
"""
Split `text` in the biggest lines possible with the constraint of `max_len` using `prefix` on the first line and
then indenting with the same length as `prefix`.
"""
text = re.sub(r"\s+", " ", text)
indent = " " * len(prefix)
if min_indent is not None:
if len(indent) < len(min_indent):
indent = min_indent
if len(prefix) < len(min_indent):
prefix = " " * (len(min_indent) - len(prefix)) + prefix
new_lines = []
words = text.split(" ")
current_line = f"{prefix}{words[0]}"
for word in words[1:]:
try_line = f"{current_line} {word}"
if len(try_line) > max_len:
new_lines.append(current_line)
current_line = f"{indent}{word}"
else:
current_line = try_line
new_lines.append(current_line)
return "\n".join(new_lines)
def get_indent(line):
"""Get the indentation of `line`."""
indent_search = _re_indent.search(line)
return indent_search.groups()[0] if indent_search is not None else ""
class CodeStyler:
"""A generic class to style .rst files."""
def is_no_style_block(self, line):
"""Whether or not `line` introduces a block where styling should be ignore"""
if _re_code_block.search(line) is not None:
return True
if _re_textual_blocks.search(line) is not None:
return False
return _re_ignore.search(line) is not None
def is_comment_or_textual_block(self, line):
"""Whether or not `line` introduces a block where styling should not be ignored (note, warnings...)"""
if _re_comment.search(line):
return True
return _re_textual_blocks.search(line) is not None
def is_special_block(self, line):
"""Whether or not `line` introduces a special block."""
if self.is_no_style_block(line):
self.in_block = SpecialBlock.NO_STYLE
return True
return False
def init_in_block(self, text):
"""
Returns the initial value for `self.in_block`.
Useful for some docstrings beginning inside an argument declaration block (all models).
"""
return SpecialBlock.NOT_SPECIAL
def style_paragraph(self, paragraph, max_len, no_style=False, min_indent=None):
"""
Style `paragraph` (a list of lines) by making sure no line goes over `max_len`, except if the `no_style` flag
is passed.
"""
if len(paragraph) == 0:
return ""
if no_style or self.in_block == SpecialBlock.NO_STYLE:
return "\n".join(paragraph)
if _re_list.search(paragraph[0]) is not None:
# Great, we're in a list. So we need to split our paragraphs in smaller parts, one for each item.
result = ""
remainder = ""
prefix = _re_list.search(paragraph[0]).groups()[0]
prefix_indent = get_indent(paragraph[0])
current_item = [paragraph[0][len(prefix) :]]
for i, line in enumerate(paragraph[1:]):
new_item_search = _re_list.search(line)
indent = get_indent(line)
if len(indent) < len(prefix_indent) or (len(indent) == len(prefix_indent) and new_item_search is None):
# There might not be an empty line after the list, formatting the remainder recursively.
remainder = "\n" + self.style_paragraph(
paragraph[i + 1 :], max_len, no_style=no_style, min_indent=min_indent
)
break
elif new_item_search is not None:
text = " ".join([l.strip() for l in current_item])
result += split_text_in_lines(text, max_len, prefix, min_indent=min_indent) + "\n"
prefix = new_item_search.groups()[0]
prefix_indent = indent
current_item = [line[len(prefix) :]]
else:
current_item.append(line)
# Treat the last item
text = " ".join([l.strip() for l in current_item])
result += split_text_in_lines(text, max_len, prefix, min_indent=min_indent)
# Add the potential remainder
return result + remainder
if len(paragraph) > 1 and self.is_comment_or_textual_block(paragraph[0]):
# Comments/notes in rst should be restyled with indentation, ignoring the first line.
indent = get_indent(paragraph[1])
text = " ".join([l.strip() for l in paragraph[1:]])
return paragraph[0] + "\n" + split_text_in_lines(text, max_len, indent, min_indent=min_indent)
if self.in_block == SpecialBlock.ARG_LIST:
# Arg lists are special: we need to ignore the lines that are at the first indentation level beneath the
# Args/Parameters (parameter description), then we can style the indentation level beneath.
result = ""
# The args/parameters could be in that paragraph and should be ignored
if _re_arg_def.search(paragraph[0]) is not None:
if len(paragraph) == 1:
return paragraph[0]
result += paragraph[0] + "\n"
paragraph = paragraph[1:]
if self.current_indent is None:
self.current_indent = get_indent(paragraph[1])
current_item = []
for line in paragraph:
if get_indent(line) == self.current_indent:
if len(current_item) > 0:
item_indent = get_indent(current_item[0])
text = " ".join([l.strip() for l in current_item])
result += split_text_in_lines(text, max_len, item_indent, min_indent=min_indent) + "\n"
result += line + "\n"
current_item = []
else:
current_item.append(line)
if len(current_item) > 0:
item_indent = get_indent(current_item[0])
text = " ".join([l.strip() for l in current_item])
result += split_text_in_lines(text, max_len, item_indent, min_indent=min_indent) + "\n"
return result[:-1]
indent = get_indent(paragraph[0])
text = " ".join([l.strip() for l in paragraph])
return split_text_in_lines(text, max_len, indent, min_indent=min_indent)
def style(self, text, max_len=119, min_indent=None):
"""Style `text` to `max_len`."""
new_lines = []
paragraph = []
self.current_indent = ""
# If one of those is True, the paragraph should not be touched (code samples, lists...)
no_style = False
no_style_next = False
self.in_block = self.init_in_block(text)
# If this is True, we force-break a paragraph, even if there is no new empty line.
break_paragraph = False
lines = text.split("\n")
last_line = None
for line in lines:
# New paragraph
line_is_empty = len(line.strip()) == 0
list_begins = (
_re_list.search(line) is not None
and last_line is not None
and len(get_indent(line)) > len(get_indent(last_line))
)
if line_is_empty or break_paragraph or list_begins:
if len(paragraph) > 0:
if self.in_block != SpecialBlock.NOT_SPECIAL:
indent = get_indent(paragraph[0])
# Are we still in a no-style block?
if self.current_indent is None:
# If current_indent is None, we haven't begun the interior of the block so the answer is
# yes, unless we have an indent of 0 in which case the special block took one line only.
if len(indent) == 0:
self.in_block = SpecialBlock.NOT_SPECIAL
else:
self.current_indent = indent
elif not indent.startswith(self.current_indent):
# If not, we are leaving the block when we unindent.
self.in_block = SpecialBlock.NOT_SPECIAL
if self.is_special_block(paragraph[0]):
# Maybe we are starting a special block.
if len(paragraph) > 1:
# If we have the interior of the block in the paragraph, we grab the indent.
self.current_indent = get_indent(paragraph[1])
else:
# We will determine the indent with the next paragraph
self.current_indent = None
styled_paragraph = self.style_paragraph(
paragraph, max_len, no_style=no_style, min_indent=min_indent
)
new_lines.append(styled_paragraph + "\n")
else:
new_lines.append("")
paragraph = []
no_style = no_style_next
no_style_next = False
last_line = None
if (not break_paragraph and not list_begins) or line_is_empty:
break_paragraph = False
continue
break_paragraph = False
# Title and section lines should go to the max + add a new paragraph.
if (
len(set(line)) == 1
and line[0] in TITLE_SPECIAL_CHARS
and last_line is not None
and len(line) >= len(last_line)
):
line = line[0] * max_len
break_paragraph = True
# A docstyle-ignore comment indicates the next paragraph should be no-style.
if _re_doc_ignore.search(line) is not None:
no_style_next = True
# Tables are in just one paragraph and should be no-style.
if _re_table.search(line) is not None:
no_style = True
paragraph.append(line)
last_line = line
# Just have to treat the last paragraph. It could still be in a no-style block (or not)
if len(paragraph) > 0:
# Are we still in a special block
# (if current_indent is None, we are but no need to set it since we are the end.)
if self.in_block != SpecialBlock.NO_STYLE and self.current_indent is not None:
indent = get_indent(paragraph[0])
if not indent.startswith(self.current_indent):
self.in_block = SpecialBlock.NOT_SPECIAL
_ = self.is_special_block(paragraph[0])
new_lines.append(self.style_paragraph(paragraph, max_len, no_style=no_style, min_indent=min_indent) + "\n")
return "\n".join(new_lines)
class DocstringStyler(CodeStyler):
"""Class to style docstrings that take the main method from `CodeStyler`."""
def is_no_style_block(self, line):
if _re_example.search(line) is not None:
return True
return _re_code_block.search(line) is not None
# return super().is_no_style_block(line) is not None
def is_comment_or_textual_block(self, line):
if _re_return.search(line) is not None:
self.in_block = SpecialBlock.NOT_SPECIAL
return True
return super().is_comment_or_textual_block(line)
def is_special_block(self, line):
if self.is_no_style_block(line):
self.in_block = SpecialBlock.NO_STYLE
return True
if _re_arg_def.search(line) is not None:
self.in_block = SpecialBlock.ARG_LIST
return True
return False
def init_in_block(self, text):
lines = text.split("\n")
while len(lines) > 0 and len(lines[0]) == 0:
lines = lines[1:]
if len(lines) == 0:
return SpecialBlock.NOT_SPECIAL
if re.search(r":\s*$", lines[0]):
indent = get_indent(lines[0])
if (
len(lines) == 1
or len(get_indent(lines[1])) > len(indent)
or (len(get_indent(lines[1])) == len(indent) and re.search(r":\s*$", lines[1]))
):
self.current_indent = indent
return SpecialBlock.ARG_LIST
return SpecialBlock.NOT_SPECIAL
rst_styler = CodeStyler()
doc_styler = DocstringStyler()
def style_rst_file(doc_file, max_len=119, check_only=False):
""" Style one rst file `doc_file` to `max_len`."""
with open(doc_file, "r", encoding="utf-8") as f:
doc = f.read()
clean_doc = rst_styler.style(doc, max_len=max_len)
diff = clean_doc != doc
if not check_only and diff:
print(f"Overwriting content of {doc_file}.")
with open(doc_file, "w", encoding="utf-8") as f:
f.write(clean_doc)
return diff
def style_docstring(docstring, max_len=119):
"""Style `docstring` to `max_len`."""
# One-line docstrings that are not too long are left as is.
if len(docstring) < max_len and "\n" not in docstring:
return docstring
# Grab the indent from the last line
last_line = docstring.split("\n")[-1]
# Is it empty except for the last triple-quotes (not-included in `docstring`)?
indent_search = re.search(r"^(\s*)$", last_line)
if indent_search is not None:
indent = indent_search.groups()[0]
if len(indent) > 0:
docstring = docstring[: -len(indent)]
# Or are the triple quotes next to text (we will fix that).
else:
indent_search = _re_indent.search(last_line)
indent = indent_search.groups()[0] if indent_search is not None else ""
# Add missing new lines before Args/Returns etc.
docstring = _re_any_doc_special_word.sub(r"\n\n\1\2\3\n", docstring)
# Style
styled_doc = doc_styler.style(docstring, max_len=max_len, min_indent=indent)
# Add new lines if necessary
if not styled_doc.startswith("\n"):
styled_doc = "\n" + styled_doc
if not styled_doc.endswith("\n"):
styled_doc += "\n"
return styled_doc + indent
def style_file_docstrings(code_file, max_len=119, check_only=False):
"""Style all docstrings in `code_file` to `max_len`."""
with open(code_file, "r", encoding="utf-8") as f:
code = f.read()
splits = code.split('"""')
splits = [
(s if i % 2 == 0 or _re_doc_ignore.search(splits[i - 1]) is not None else style_docstring(s, max_len=max_len))
for i, s in enumerate(splits)
]
clean_code = '"""'.join(splits)
diff = clean_code != code
if not check_only and diff:
print(f"Overwriting content of {code_file}.")
with open(code_file, "w", encoding="utf-8") as f:
f.write(clean_code)
return diff
def style_doc_files(*files, max_len=119, check_only=False):
"""
Style all `files` to `max_len` and fixes mistakes if not `check_only`, otherwise raises an error if styling should
be done.
"""
changed = []
for file in files:
# Treat folders
if os.path.isdir(file):
files = [os.path.join(file, f) for f in os.listdir(file)]
files = [f for f in files if os.path.isdir(f) or f.endswith(".rst") or f.endswith(".py")]
changed += style_doc_files(*files, max_len=max_len, check_only=check_only)
# Treat rst
elif file.endswith(".rst"):
if style_rst_file(file, max_len=max_len, check_only=check_only):
changed.append(file)
# Treat python files
elif file.endswith(".py"):
if style_file_docstrings(file, max_len=max_len, check_only=check_only):
changed.append(file)
else:
warnings.warn(f"Ignoring {file} because it's not a py or an rst file or a folder.")
return changed
def main(*files, max_len=119, check_only=False):
changed = style_doc_files(*files, max_len=max_len, check_only=check_only)
if check_only and len(changed) > 0:
raise ValueError(f"{len(changed)} files should be restyled!")
elif len(changed) > 0:
print(f"Cleaned {len(changed)} files!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
parser.add_argument("--max_len", type=int, help="The maximum length of lines.")
parser.add_argument("--check_only", action="store_true", help="Whether to only check and not fix styling issues.")
args = parser.parse_args()
main(*args.files, max_len=args.max_len, check_only=args.check_only)
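# Example invocations (illustrative paths, adjust to your checkout):
#     python utils/style_doc.py src/transformers docs/source --max_len 119
#     python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only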