"...glm-4v-9b_pytorch.git" did not exist on "1bfbcff03ba1d1fc1701616cd9475416b689eeb6"
Commit f753d4e3 authored by LysandreJik

Removed typings for Python 2

parent 75bc2a03
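Function annotations in `def` signatures are a Python 3 only feature (PEP 3107); under Python 2 they are a SyntaxError, so the hints had to be removed rather than merely ignored. Below is a minimal sketch of the before/after pattern, not code taken from the repository, mirroring the `torch.tensor` annotations that the diff strips:

import torch

# Before: annotated signature in the style the diff removes. Python 2 cannot
# even parse this line; in Python 3 the annotations are stored but not enforced.
def forward_annotated(query: torch.tensor, head_mask: torch.tensor = None):
    return query if head_mask is None else query * head_mask

# After: the same function with the annotations stripped, valid on both
# Python 2 and Python 3 and behaviourally identical.
def forward_plain(query, head_mask=None):
    return query if head_mask is None else query * head_mask

x = torch.ones(2, 4)
assert torch.equal(forward_annotated(x), forward_plain(x))   # same results either way
print(forward_annotated.__annotations__)   # annotations are stored as metadata only
print(forward_plain.__annotations__)       # {} once they are removed

Note also that the removed annotations referenced `torch.tensor` (the factory function) rather than the `torch.Tensor` class, so they were informational rather than meaningful type hints in the first place.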
@@ -158,8 +158,7 @@ class Embeddings(nn.Module):
         return embeddings
 class MultiHeadSelfAttention(nn.Module):
-    def __init__(self,
-                 config):
+    def __init__(self, config):
         super(MultiHeadSelfAttention, self).__init__()
         self.n_heads = config.n_heads
@@ -192,12 +191,7 @@ class MultiHeadSelfAttention(nn.Module):
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
-    def forward(self,
-                query: torch.tensor,
-                key: torch.tensor,
-                value: torch.tensor,
-                mask: torch.tensor,
-                head_mask: torch.tensor = None):
+    def forward(self, query, key, value, mask, head_mask = None):
         """
         Parameters
         ----------
@@ -258,8 +252,7 @@ class MultiHeadSelfAttention(nn.Module):
             return (context,)
 class FFN(nn.Module):
-    def __init__(self,
-                 config):
+    def __init__(self, config):
         super(FFN, self).__init__()
         self.dropout = nn.Dropout(p=config.dropout)
         self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
@@ -267,8 +260,7 @@ class FFN(nn.Module):
         assert config.activation in ['relu', 'gelu'], "activation ({}) must be in ['relu', 'gelu']".format(config.activation)
         self.activation = gelu if config.activation == 'gelu' else nn.ReLU()
-    def forward(self,
-                input: torch.tensor):
+    def forward(self, input):
         x = self.lin1(input)
         x = self.activation(x)
         x = self.lin2(x)
@@ -276,8 +268,7 @@ class FFN(nn.Module):
         return x
 class TransformerBlock(nn.Module):
-    def __init__(self,
-                 config):
+    def __init__(self, config):
         super(TransformerBlock, self).__init__()
         self.n_heads = config.n_heads
@@ -295,10 +286,7 @@ class TransformerBlock(nn.Module):
         self.ffn = FFN(config)
         self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
-    def forward(self,
-                x: torch.tensor,
-                attn_mask: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, x, attn_mask=None, head_mask=None):
         """
         Parameters
         ----------
@@ -332,8 +320,7 @@ class TransformerBlock(nn.Module):
 class Transformer(nn.Module):
-    def __init__(self,
-                 config):
+    def __init__(self, config):
         super(Transformer, self).__init__()
         self.n_layers = config.n_layers
         self.output_attentions = config.output_attentions
@@ -342,10 +329,7 @@ class Transformer(nn.Module):
         layer = TransformerBlock(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])
-    def forward(self,
-                x: torch.tensor,
-                attn_mask: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, x, attn_mask=None, head_mask=None):
         """
         Parameters
         ----------
@@ -512,9 +496,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
             self.transformer.layer[layer].attention.prune_heads(heads)
     def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                head_mask: torch.tensor = None):
+                input_ids, attention_mask=None, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids) # (bs, seq_length)
@@ -597,11 +579,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self._tie_or_clone_weights(self.vocab_projector,
                                    self.distilbert.embeddings.word_embeddings)
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                masked_lm_labels: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, input_ids, attention_mask=None, masked_lm_labels=None, head_mask=None):
         dlbrt_output = self.distilbert(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        head_mask=head_mask)
@@ -665,11 +643,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
         self.apply(self.init_weights)
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                labels: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
                                             attention_mask=attention_mask,
                                             head_mask=head_mask)
@@ -743,12 +717,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
         self.apply(self.init_weights)
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                start_positions: torch.tensor = None,
-                end_positions: torch.tensor = None,
-                head_mask: torch.tensor = None):
+    def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
                                             attention_mask=attention_mask,
                                             head_mask=head_mask)
...
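For reference, a usage sketch against the untyped signature kept by this commit, forward(input_ids, attention_mask=None, head_mask=None). The import path (pytorch_transformers) and the checkpoint name are assumptions about the surrounding library version, not something this diff establishes:

import torch
from pytorch_transformers import DistilBertModel

# Load a pretrained DistilBERT encoder (checkpoint name assumed).
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
model.eval()

input_ids = torch.tensor([[101, 7592, 1010, 2088, 999, 102]])  # (bs, seq_length), hypothetical token ids
attention_mask = torch.ones_like(input_ids)                    # (bs, seq_length)

with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)

hidden_state = outputs[0]   # (bs, seq_length, dim), last-layer hidden states
print(hidden_state.shape)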