Commit f753d4e3 authored by LysandreJik

Removed typings for Python 2

parent 75bc2a03
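Context for the change, with a toy sketch that is not part of the commit: function annotations such as `x: torch.tensor = None` are Python 3-only syntax, so a module containing them cannot even be imported under Python 2. Dropping the annotations, as the hunks below do, keeps the module importable on both interpreters. The `ToyBlock` module here is hypothetical and only mirrors the before/after signature style:

```python
# Toy sketch, not from the commit: ToyBlock is a hypothetical module that
# mirrors the signature change applied to the DistilBERT sub-modules below.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim=16):
        super(ToyBlock, self).__init__()
        self.lin = nn.Linear(dim, dim)

    # Before (Python 3 only; Python 2 raises SyntaxError at import time):
    #     def forward(self, x: torch.tensor, attn_mask: torch.tensor = None):
    # After (parses under Python 2 and Python 3):
    def forward(self, x, attn_mask=None):
        out = self.lin(x)
        if attn_mask is not None:
            out = out * attn_mask  # mask broadcasts over the feature dimension
        return out


if __name__ == "__main__":
    block = ToyBlock()
    print(block(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```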
@@ -158,8 +158,7 @@ class Embeddings(nn.Module):
         return embeddings
 class MultiHeadSelfAttention(nn.Module):
-    def __init__(self,
-                 config):
+    def __init__(self, config):
         super(MultiHeadSelfAttention, self).__init__()
         self.n_heads = config.n_heads
@@ -192,12 +191,7 @@ class MultiHeadSelfAttention(nn.Module):
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
-    def forward(self,
-                query: torch.tensor,
-                key: torch.tensor,
-                value: torch.tensor,
-                mask: torch.tensor,
-                head_mask: torch.tensor = None):
+    def forward(self, query, key, value, mask, head_mask = None):
         """
         Parameters
         ----------
@@ -258,8 +252,7 @@ class MultiHeadSelfAttention(nn.Module):
             return (context,)
 class FFN(nn.Module):
-    def __init__(self,
-                 config):
+    def __init__(self, config):
         super(FFN, self).__init__()
         self.dropout = nn.Dropout(p=config.dropout)
         self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
@@ -267,8 +260,7 @@ class FFN(nn.Module):
         assert config.activation in ['relu', 'gelu'], "activation ({}) must be in ['relu', 'gelu']".format(config.activation)
         self.activation = gelu if config.activation == 'gelu' else nn.ReLU()
-    def forward(self,
-                input: torch.tensor):
+    def forward(self, input):
         x = self.lin1(input)
         x = self.activation(x)
         x = self.lin2(x)
@@ -276,8 +268,7 @@ class FFN(nn.Module):
         return x
 class TransformerBlock(nn.Module):
-    def __init__(self,
-                 config):
+    def __init__(self, config):
         super(TransformerBlock, self).__init__()
         self.n_heads = config.n_heads
@@ -295,10 +286,7 @@ class TransformerBlock(nn.Module):
         self.ffn = FFN(config)
         self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
-    def forward(self,
-                x: torch.tensor,
-                attn_mask: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, x, attn_mask=None, head_mask=None):
         """
         Parameters
         ----------
@@ -332,8 +320,7 @@ class TransformerBlock(nn.Module):
 class Transformer(nn.Module):
-    def __init__(self,
-                 config):
+    def __init__(self, config):
         super(Transformer, self).__init__()
         self.n_layers = config.n_layers
         self.output_attentions = config.output_attentions
@@ -342,10 +329,7 @@ class Transformer(nn.Module):
         layer = TransformerBlock(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])
-    def forward(self,
-                x: torch.tensor,
-                attn_mask: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, x, attn_mask=None, head_mask=None):
         """
         Parameters
         ----------
@@ -512,9 +496,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
             self.transformer.layer[layer].attention.prune_heads(heads)
     def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                head_mask: torch.tensor = None):
+                input_ids, attention_mask=None, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids) # (bs, seq_length)
@@ -597,11 +579,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self._tie_or_clone_weights(self.vocab_projector,
                                    self.distilbert.embeddings.word_embeddings)
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                masked_lm_labels: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, input_ids, attention_mask=None, masked_lm_labels=None, head_mask=None):
         dlbrt_output = self.distilbert(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        head_mask=head_mask)
@@ -665,11 +643,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
         self.apply(self.init_weights)
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                labels: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
                                             attention_mask=attention_mask,
                                             head_mask=head_mask)
@@ -743,12 +717,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
         self.apply(self.init_weights)
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                start_positions: torch.tensor = None,
-                end_positions: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
                                             attention_mask=attention_mask,
                                             head_mask=head_mask)
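A side note on the trade-off, not part of the commit: if the type information were worth keeping, PEP 484 type comments express the same hints in a form the Python 2 parser accepts. A minimal sketch with a hypothetical free function (the `typing` backport package is assumed under Python 2):

```python
# Sketch only, not the repository's code: a PEP 484 type comment keeps the
# hints without Python 3 annotation syntax, so it parses under Python 2 too.
import torch
from typing import Optional, Tuple


def attention_forward(query, key, value, mask, head_mask=None):
    # type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]) -> Tuple[torch.Tensor, ...]
    """Hypothetical stand-in mirroring the MultiHeadSelfAttention.forward signature."""
    scores = torch.matmul(query, key.transpose(-1, -2))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    if head_mask is not None:
        weights = weights * head_mask
    return (torch.matmul(weights, value),)
```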