Unverified Commit 5bdf3313 authored by John Ryan, committed by GitHub

Adding type hints for Distilbert (#16090)



* DistilBert type - squash

* Update src/transformers/models/distilbert/modeling_distilbert.py

Undo cleanup

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/distilbert/modeling_distilbert.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/distilbert/modeling_distilbert.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/distilbert/modeling_distilbert.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Remove type

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 0b8b0618
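For context, the sketch below is not part of this commit; it only shows how the annotated signatures read at a call site, assuming a standard `transformers` install and the `distilbert-base-uncased` checkpoint. With `return_dict=True` the model returns a `BaseModelOutput`, and with `return_dict=False` a plain tuple of tensors, which is why the return annotations added in the diff are `Union[BaseModelOutput, Tuple[torch.Tensor, ...]]`.

```python
# Illustrative only: exercises the typed DistilBertModel.forward signature.
from typing import Tuple, Union

import torch
from transformers import DistilBertModel, DistilBertTokenizer
from transformers.modeling_outputs import BaseModelOutput

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Hello, DistilBERT!", return_tensors="pt")

# With return_dict=True the forward pass yields a BaseModelOutput ...
out: Union[BaseModelOutput, Tuple[torch.Tensor, ...]] = model(**inputs, return_dict=True)
assert isinstance(out, BaseModelOutput)

# ... and with return_dict=False it yields a plain tuple of tensors,
# which is why the annotation is a Union of the two.
out_tuple = model(**inputs, return_dict=False)
assert isinstance(out_tuple, tuple)
```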
src/transformers/models/distilbert/modeling_distilbert.py
@@ -19,6 +19,7 @@
 import math
+from typing import Dict, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch
@@ -26,6 +27,8 @@ from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.configuration_utils import PretrainedConfig
 from ...activations import get_activation
 from ...deepspeed import is_deepspeed_zero3_enabled
 from ...file_utils import (
@@ -72,7 +75,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 # UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
-def create_sinusoidal_embeddings(n_pos, dim, out):
+def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
     if is_deepspeed_zero3_enabled():
         import deepspeed
@@ -83,7 +86,7 @@ def create_sinusoidal_embeddings(n_pos, dim, out):
         _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
-def _create_sinusoidal_embeddings(n_pos, dim, out):
+def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
     position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
     out.requires_grad = False
     out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
@@ -92,7 +95,7 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):
 class Embeddings(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
@@ -108,7 +111,7 @@ class Embeddings(nn.Module):
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
-    def forward(self, input_ids):
+    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         """
         Parameters:
             input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
@@ -137,7 +140,7 @@ class Embeddings(nn.Module):
 class MultiHeadSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.n_heads = config.n_heads
@@ -151,9 +154,9 @@ class MultiHeadSelfAttention(nn.Module):
         self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
         self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
-        self.pruned_heads = set()
+        self.pruned_heads: Set[int] = set()
-    def prune_heads(self, heads):
+    def prune_heads(self, heads: List[int]):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
@@ -168,7 +171,15 @@ class MultiHeadSelfAttention(nn.Module):
         self.dim = attention_head_size * self.n_heads
         self.pruned_heads = self.pruned_heads.union(heads)
-    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, ...]:
         """
         Parameters:
             query: torch.tensor(bs, seq_length, dim)
@@ -189,11 +200,11 @@ class MultiHeadSelfAttention(nn.Module):
         mask_reshp = (bs, 1, 1, k_length)
-        def shape(x):
+        def shape(x: torch.Tensor) -> torch.Tensor:
             """separate heads"""
             return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
-        def unshape(x):
+        def unshape(x: torch.Tensor) -> torch.Tensor:
             """group heads"""
             return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
@@ -224,7 +235,7 @@ class MultiHeadSelfAttention(nn.Module):
 class FFN(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.dropout = nn.Dropout(p=config.dropout)
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -233,10 +244,10 @@ class FFN(nn.Module):
         self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
         self.activation = get_activation(config.activation)
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
-    def ff_chunk(self, input):
+    def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
         x = self.lin1(input)
         x = self.activation(x)
         x = self.lin2(x)
@@ -245,7 +256,7 @@ class FFN(nn.Module):
 class TransformerBlock(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         assert config.dim % config.n_heads == 0
@@ -256,7 +267,13 @@ class TransformerBlock(nn.Module):
         self.ffn = FFN(config)
         self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
-    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, ...]:
         """
         Parameters:
             x: torch.tensor(bs, seq_length, dim)
@@ -284,7 +301,7 @@ class TransformerBlock(nn.Module):
         # Feed Forward Network
         ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
-        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)
+        ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)
         output = (ffn_output,)
         if output_attentions:
@@ -293,14 +310,20 @@ class TransformerBlock(nn.Module):
 class Transformer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.n_layers = config.n_layers
         self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
     def forward(
-        self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None
-    ):  # docstyle-ignore
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:  # docstyle-ignore
         """
         Parameters:
             x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
@@ -357,7 +380,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     load_tf_weights = None
     base_model_prefix = "distilbert"
-    def _init_weights(self, module):
+    def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
@@ -432,7 +455,7 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertModel(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)
         self.embeddings = Embeddings(config)  # Embeddings
@@ -489,13 +512,13 @@ class DistilBertModel(DistilBertPreTrainedModel):
         # move position_embeddings to correct device
         self.embeddings.position_embeddings.to(self.device)
-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings.word_embeddings
-    def set_input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings: nn.Embedding):
         self.embeddings.word_embeddings = new_embeddings
-    def _prune_heads(self, heads_to_prune):
+    def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]):
         """
         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
         class PreTrainedModel
@@ -512,14 +535,14 @@ class DistilBertModel(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -560,7 +583,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForMaskedLM(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)
         self.activation = get_activation(config.activation)
@@ -595,10 +618,10 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         """
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
-    def get_output_embeddings(self):
+    def get_output_embeddings(self) -> nn.Module:
         return self.vocab_projector
-    def set_output_embeddings(self, new_embeddings):
+    def set_output_embeddings(self, new_embeddings: nn.Module):
         self.vocab_projector = new_embeddings
     @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
@@ -610,15 +633,15 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[MaskedLMOutput, Tuple[torch.Tensor, ...]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -666,7 +689,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
@@ -708,15 +731,15 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -784,7 +807,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)
         self.distilbert = DistilBertModel(config)
@@ -824,16 +847,16 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor, ...]]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -901,7 +924,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForTokenClassification(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -941,15 +964,15 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[TokenClassifierOutput, Tuple[torch.Tensor, ...]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -996,7 +1019,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__(config)
         self.distilbert = DistilBertModel(config)
@@ -1033,15 +1056,15 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
     @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[MultipleChoiceModelOutput, Tuple[torch.Tensor, ...]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
......
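The diff is truncated above. As a rough way to eyeball the result once the change is applied, the snippet below (illustrative, not part of the PR) uses the standard `inspect` module to print the parameter and return annotations that this commit adds to `DistilBertModel.forward`.

```python
# Hypothetical sanity check: list the new annotations on DistilBertModel.forward.
import inspect

from transformers import DistilBertModel

sig = inspect.signature(DistilBertModel.forward)
for name, param in sig.parameters.items():
    # Prints e.g. "input_ids typing.Optional[torch.Tensor]" for annotated parameters.
    print(name, param.annotation)
print("return ->", sig.return_annotation)
```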