Unverified Commit e68ec18c authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Docs: formatting nits (#32247)



* doc formatting nits

* ignore non-autodocs

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/esm/modeling_esm.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/esm/modeling_esm.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make fixup

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 2fbbcf50
......@@ -647,8 +647,9 @@ class YolosModel(YolosPreTrainedModel):
Prunes heads of the model.
Args:
heads_to_prune (`dict` of {layer_num: list of heads to prune in this layer}):
See base class `PreTrainedModel`.
heads_to_prune (`dict`):
See base class `PreTrainedModel`. The input dictionary must have the following format: {layer_num:
list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
......
......@@ -218,7 +218,7 @@ def infer_framework_load_model(
If both frameworks are installed and available for `model`, PyTorch is selected.
Args:
model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel`]):
model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`):
The model to infer the framework from. If `str`, a checkpoint name. The model to infer the framewrok from.
config ([`AutoConfig`]):
The config associated with the model to help using the correct class
......@@ -322,7 +322,7 @@ def infer_framework_from_model(
If both frameworks are installed and available for `model`, PyTorch is selected.
Args:
model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel`]):
model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`):
The model to infer the framework from. If `str`, a checkpoint name. The model to infer the framewrok from.
model_classes (dictionary `str` to `type`, *optional*):
A mapping framework to class.
......@@ -349,7 +349,7 @@ def get_framework(model, revision: Optional[str] = None):
Select framework (TensorFlow or PyTorch) to use.
Args:
model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel`]):
model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`):
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
the model name). If no specific model is provided, defaults to using PyTorch.
"""
......@@ -385,7 +385,7 @@ def get_default_model_and_revision(
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
Args:
targeted_task (`Dict` ):
targeted_task (`Dict`):
Dictionary representing the given task, that should contain default models
framework (`str`, None)
......
......@@ -22,7 +22,7 @@ logger = logging.get_logger(__name__)
@add_end_docstrings(
build_pipeline_init_args(has_tokenizer=True),
r"""
top_k (`int`, defaults to 5):
top_k (`int`, *optional*, defaults to 5):
The number of predictions to return.
targets (`str` or `List[str]`, *optional*):
When passed, the model will limit the scores to the passed targets instead of looking up in the whole
......
......@@ -31,7 +31,7 @@ class PipelineIterator(IterableDataset):
```
Arguments:
loader (`torch.utils.data.DataLoader` or any iterator):
loader (`torch.utils.data.DataLoader` or `Iterable`):
The iterator that will be used to apply `infer` on.
infer (any function):
The function to apply of each element of `loader`.
......@@ -163,7 +163,7 @@ class PipelineChunkIterator(PipelineIterator):
```
Arguments:
loader (`torch.utils.data.DataLoader` or any iterator):
loader (`torch.utils.data.DataLoader` or `Iterable`):
The iterator that will be used to apply `infer` on.
infer (any function):
The function to apply of each element of `loader`.
......@@ -224,7 +224,7 @@ class PipelinePackIterator(PipelineIterator):
```
Arguments:
loader (`torch.utils.data.DataLoader` or any iterator):
loader (`torch.utils.data.DataLoader` or `Iterable`):
The iterator that will be used to apply `infer` on.
infer (any function):
The function to apply of each element of `loader`.
......
......@@ -3200,7 +3200,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
</Tip>
Args:
text (`str`, `List[str]` or `List[int]` (the latter only for not-fast tokenizers)):
text (`str`, `List[str]` or (for non-fast tokenizers) `List[int]`):
The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
`tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
method).
......
......@@ -745,7 +745,7 @@ class Trainer:
Add a callback to the current list of [`~transformers.TrainerCallback`].
Args:
callback (`type` or [`~transformers.TrainerCallback`]):
callback (`type` or [`~transformers.TrainerCallback]`):
A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
first case, will instantiate a member of that class.
"""
......@@ -758,7 +758,7 @@ class Trainer:
If the callback is not found, returns `None` (and no error is raised).
Args:
callback (`type` or [`~transformers.TrainerCallback`]):
callback (`type` or [`~transformers.TrainerCallback]`):
A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
first case, will pop the first member of that class found in the list of callbacks.
......@@ -772,7 +772,7 @@ class Trainer:
Remove a callback from the current list of [`~transformers.TrainerCallback`].
Args:
callback (`type` or [`~transformers.TrainerCallback`]):
callback (`type` or [`~transformers.TrainerCallback]`):
A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
first case, will remove the first member of that class found in the list of callbacks.
"""
......
......@@ -80,7 +80,7 @@ class Seq2SeqTrainer(Trainer):
Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` arguments.
Args:
gen_config_arg (`str` or [`~generation.GenerationConfig`]):
gen_config_arg (`str` or [`~generation.GenerationConfig]`):
`Seq2SeqTrainingArguments.generation_config` argument.
Returns:
......
......@@ -1605,7 +1605,7 @@ def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
Args:
path (`str`): The path to the source file
file (`str`, optional): The file to join with the path. Defaults to "__init__.py".
file (`str`, *optional*): The file to join with the path. Defaults to "__init__.py".
Returns:
`ModuleType`: The resulting imported module
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment