"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6e603cb7892b49a2cbbc10ba859759f92c3fb7a6"
Unverified commit f0fd73a2 authored by Sylvain Gugger, committed by GitHub

Document check copies (#25291)

* Document check copies better and add tests

* Include header in check for copies

* Manual fixes

* Try autofix

* Fixes

* Clean tests

* Finalize doc

* Remove debug print

* More fixes
parent 29f04002
@@ -101,7 +101,7 @@ own regarding how code should be written :-)
1. The forward pass of your model should be fully written in the modeling file while being fully independent of other
models in the library. If you want to reuse a block from another model, copy the code and paste it with a
`# Copied from` comment on top (see [here](https://github.com/huggingface/transformers/blob/v4.17.0/src/transformers/models/roberta/modeling_roberta.py#L160)
-for a good example).
+for a good example and [there](pr_checks#check-copies) for more documentation on Copied from).
2. The code should be fully understandable, even by a non-native English speaker. This means you should pick
descriptive variable names and avoid abbreviations. As an example, `activation` is preferred to `act`.
One-letter variable names are strongly discouraged unless it's an index in a for loop.
......
@@ -142,3 +142,58 @@ Additional checks concern PRs that add new models, mainly that:
- All checkpoints used actually exist on the Hub
-->
### Check copies
Since the Transformers library is very opinionated with respect to model code, and each model should be fully implemented in a single file without relying on other models, we have added a mechanism that checks whether a copy of the code of a layer of a given model stays consistent with the original. This way, when there is a bug fix, we can see all other impacted models and choose to trickle down the modification or break the copy.
<Tip>
If a file is a full copy of another file, you should register it in the constant `FULL_COPIES` of `utils/check_copies.py`.
</Tip>
This mechanism relies on comments of the form `# Copied from xxx`. The `xxx` should contain the full path to the class or function that is being copied below. For instance, `RobertaSelfOutput` is a direct copy of the `BertSelfOutput` class, so you can see [here](https://github.com/huggingface/transformers/blob/2bd7a27a671fd1d98059124024f580f8f5c0f3b5/src/transformers/models/roberta/modeling_roberta.py#L289) that it has the comment:
```py
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
```
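In the modeling file itself, the comment sits directly above the definition it covers. Abridged here for illustration (the real class lives in `modeling_roberta.py`, and the imports are added only to make the snippet self-contained), it looks roughly like this:
```py
import torch
from torch import nn


# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class RobertaSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
```
Everything below the comment has to stay identical to `BertSelfOutput`; if it drifts, the check fails.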
Note that instead of applying this to a whole class, you can apply it to the individual methods that are copied. For instance, [here](https://github.com/huggingface/transformers/blob/2bd7a27a671fd1d98059124024f580f8f5c0f3b5/src/transformers/models/roberta/modeling_roberta.py#L598) you can see how `RobertaPreTrainedModel._init_weights` is copied from the same method in `BertPreTrainedModel` with the comment:
```py
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
```
Sometimes the copy is identical except for names: for instance, in `RobertaAttention` we use `RobertaSelfAttention` instead of `BertSelfAttention`, but other than that the code is exactly the same. This is why `# Copied from` supports simple string replacements with the following syntax: `Copied from xxx with foo->bar`. This means the code is copied with all instances of `foo` replaced by `bar`. You can see how it is used [here](https://github.com/huggingface/transformers/blob/2bd7a27a671fd1d98059124024f580f8f5c0f3b5/src/transformers/models/roberta/modeling_roberta.py#L304C1-L304C86) in `RobertaAttention` with the comment:
```py
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
```
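The effect of the replacement is easiest to see in the copied class itself. Abridged from `modeling_roberta.py` (only the constructor is shown, roughly; the remaining methods and the surrounding imports are omitted), every `Bert` in the original `BertAttention` becomes `Roberta` here:
```py
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
class RobertaAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = RobertaSelfOutput(config)
        self.pruned_heads = set()
```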
Note that there shouldn't be any spaces around the arrow (unless that space is part of the pattern to replace of course).
You can add several patterns separated by a comma. For instance, `CamembertForMaskedLM` is a direct copy of `RobertaForMaskedLM` with two replacements: `Roberta` to `Camembert` and `ROBERTA` to `CAMEMBERT`. You can see [here](https://github.com/huggingface/transformers/blob/15082a9dc6950ecae63a0d3e5060b2fc7f15050a/src/transformers/models/camembert/modeling_camembert.py#L929) that this is done with the comment:
```py
# Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM with Roberta->Camembert, ROBERTA->CAMEMBERT
```
Note that the replacements are executed from left to right, so if one of them might conflict with another, order them accordingly.
<Tip>
If the replacements change the formatting (for instance, if you replace a short name with a very long one), the copy is checked after applying the auto-formatter.
</Tip>
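To make the mechanism concrete, here is a minimal, purely illustrative sketch of what the consistency check boils down to. The helper name `is_consistent_copy` is made up for this example; the real `utils/check_copies.py` works on the raw source files (and, as noted above, runs the auto-formatter before comparing), but the idea is the same: apply the replacements from left to right to the original's source, then compare with the copy.
```py
import inspect

from transformers.models.bert.modeling_bert import BertAttention
from transformers.models.roberta.modeling_roberta import RobertaAttention


def is_consistent_copy(original, copy, replacements=()):
    """Check that `copy` matches `original` once the replacements are applied."""
    original_src = inspect.getsource(original)
    # Replacements are applied from left to right, like in the comment syntax.
    for old, new in replacements:
        original_src = original_src.replace(old, new)
    # Compare everything below the `class`/`def` header line, so that a direct copy
    # with a different name (e.g. `RobertaSelfOutput` vs `BertSelfOutput`) still matches.
    return original_src.splitlines()[1:] == inspect.getsource(copy).splitlines()[1:]


# Should print True as long as the copy is in sync with its original.
print(is_consistent_copy(BertAttention, RobertaAttention, [("Bert", "Roberta")]))
```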
When the patterns are just different casings of the same replacement (with uppercased and lowercased variants), another option is to add `all-casing`. [Here](https://github.com/huggingface/transformers/blob/15082a9dc6950ecae63a0d3e5060b2fc7f15050a/src/transformers/models/mobilebert/modeling_mobilebert.py#L1237) is an example in `MobileBertForSequenceClassification` with the comment:
```py
# Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification with Bert->MobileBert all-casing
```
In this case, the code is copied from `BertForSequenceClassification` by applying the following replacements (sketched in code after this list):
- `Bert` by `MobileBert` (for instance when using `MobileBertModel` in the init)
- `bert` by `mobilebert` (for instance when defining `self.mobilebert`)
- `BERT` by `MOBILEBERT` (in the constant `MOBILEBERT_INPUTS_DOCSTRING`)
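A minimal, purely illustrative sketch of what `all-casing` amounts to (the helper `expand_all_casing` below is hypothetical, not part of `utils/check_copies.py`): a single replacement is expanded into its original, lowercased and uppercased variants before being applied.
```py
# Hypothetical helper, for illustration only: expand one replacement into the
# three casings that the `all-casing` option covers.
def expand_all_casing(old: str, new: str) -> list[tuple[str, str]]:
    return [(old, new), (old.lower(), new.lower()), (old.upper(), new.upper())]


print(expand_all_casing("Bert", "MobileBert"))
# [('Bert', 'MobileBert'), ('bert', 'mobilebert'), ('BERT', 'MOBILEBERT')]
```
When the check complains that a copy has drifted from its original, running `make fix-copies` from the root of the repository should propagate the change to all copies automatically.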
@@ -1168,6 +1168,7 @@ else:
"BartForSequenceClassification",
"BartModel",
"BartPretrainedModel",
+"BartPreTrainedModel",
"PretrainedBartModel",
]
)
@@ -5072,6 +5073,7 @@ if TYPE_CHECKING:
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
+BartPreTrainedModel,
BartPretrainedModel,
PretrainedBartModel,
)
......
@@ -173,7 +173,6 @@ class FlaxAlbertEmbeddings(nn.Module):
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
-# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__
def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
# Embed
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
......
@@ -183,10 +183,10 @@ class AlbertTokenizer(PreTrainedTokenizer):
self.sp_model.Load(vocab_file)
@property
-def vocab_size(self):
+def vocab_size(self) -> int:
return len(self.sp_model)
-def get_vocab(self):
+def get_vocab(self) -> Dict[str, int]:
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
......
@@ -286,7 +286,7 @@ def align_loss(similarity: torch.Tensor) -> torch.Tensor:
return (caption_loss + image_loss) / 2.0
-# Copied from transformers.models.efficientnet.modeling_efficientnet.round_filters with EfficientNet -> AlignVision
+# Copied from transformers.models.efficientnet.modeling_efficientnet.round_filters with EfficientNet->AlignVision
def round_filters(config: AlignVisionConfig, num_channels: int):
r"""
Round number of filters based on depth multiplier.
......
@@ -49,6 +49,7 @@ else:
"BartForQuestionAnswering",
"BartForSequenceClassification",
"BartModel",
+"BartPreTrainedModel",
"BartPretrainedModel",
"PretrainedBartModel",
]
@@ -107,6 +108,7 @@ if TYPE_CHECKING:
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
+BartPreTrainedModel,
BartPretrainedModel,
PretrainedBartModel,
)
......
@@ -502,7 +502,7 @@ class BartClassificationHead(nn.Module):
return hidden_states
-class BartPretrainedModel(PreTrainedModel):
+class BartPreTrainedModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
@@ -536,10 +536,18 @@ class BartPretrainedModel(PreTrainedModel):
return dummy_inputs
-class PretrainedBartModel(BartPretrainedModel):
+class PretrainedBartModel(BartPreTrainedModel):
def __init_subclass__(self):
warnings.warn(
-"The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.",
+"The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.",
+FutureWarning,
+)
+class BartPretrainedModel(BartPreTrainedModel):
+def __init_subclass__(self):
+warnings.warn(
+"The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.",
FutureWarning,
)
@@ -700,7 +708,7 @@ BART_INPUTS_DOCSTRING = r"""
"""
-class BartEncoder(BartPretrainedModel):
+class BartEncoder(BartPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`BartEncoderLayer`].
@@ -882,7 +890,7 @@ class BartEncoder(BartPretrainedModel):
)
-class BartDecoder(BartPretrainedModel):
+class BartDecoder(BartPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
@@ -1169,7 +1177,7 @@ class BartDecoder(BartPretrainedModel):
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
-class BartModel(BartPretrainedModel):
+class BartModel(BartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig):
@@ -1296,7 +1304,7 @@ class BartModel(BartPretrainedModel):
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
-class BartForConditionalGeneration(BartPretrainedModel):
+class BartForConditionalGeneration(BartPreTrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
@@ -1471,7 +1479,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
""",
BART_START_DOCSTRING,
)
-class BartForSequenceClassification(BartPretrainedModel):
+class BartForSequenceClassification(BartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig, **kwargs):
@@ -1601,7 +1609,7 @@ class BartForSequenceClassification(BartPretrainedModel):
""",
BART_START_DOCSTRING,
)
-class BartForQuestionAnswering(BartPretrainedModel):
+class BartForQuestionAnswering(BartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
@@ -1719,7 +1727,7 @@ class BartForQuestionAnswering(BartPretrainedModel):
)
-class BartDecoderWrapper(BartPretrainedModel):
+class BartDecoderWrapper(BartPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the [`EncoderDecoderModel`] framework.
@@ -1739,7 +1747,7 @@ class BartDecoderWrapper(BartPretrainedModel):
""",
BART_START_DOCSTRING,
)
-class BartForCausalLM(BartPretrainedModel):
+class BartForCausalLM(BartPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
@@ -300,7 +300,7 @@ class BitEmbeddings(nn.Module):
# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
......
@@ -22,7 +22,6 @@ from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
-import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
@@ -205,7 +204,7 @@ BLENDERBOT_DECODE_INPUTS_DOCSTRING = r"""
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
......
@@ -216,7 +216,7 @@ BLENDERBOT_SMALL_DECODE_INPUTS_DOCSTRING = r"""
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
......
@@ -160,7 +160,7 @@ class CLIPSegImageSegmentationOutput(ModelOutput):
class CLIPSegVisionEmbeddings(nn.Module):
-# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.__init__
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.__init__ with CLIP->CLIPSeg
def __init__(self, config: CLIPSegVisionConfig):
super().__init__()
self.config = config
......
@@ -861,7 +861,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
return target
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare
-def prepare(self, image, target, return_segmentation_masks=False, masks_path=None):
+def prepare(self, image, target, return_segmentation_masks=None, masks_path=None):
logger.warning_once(
"The `prepare` method is deprecated and will be removed in a v4.33. "
"Please use `prepare_annotation` instead. Note: the `prepare_annotation` method "
......
@@ -61,7 +61,7 @@ CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
......
@@ -61,7 +61,7 @@ CONVNEXTV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
......
@@ -78,7 +78,7 @@ class BaseModelOutputWithCLSToken(ModelOutput):
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
......
@@ -54,7 +54,7 @@ VAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
......
@@ -269,7 +269,7 @@ class DinatDownsampler(nn.Module):
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
......
@@ -316,7 +316,7 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
-class Dinov2DropPath:
+class Dinov2DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
......
@@ -295,8 +295,8 @@ class DonutSwinPatchMerging(nn.Module):
return input_feature
-# Copied from transformers.models.swin.modeling_swin.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
......