Commit fa2ccbc0 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Fix E266 flake8 warning (x90).

parent 2ab78325
......@@ -70,7 +70,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
)
......
......@@ -82,7 +82,7 @@ def convert_xlnet_checkpoint_to_pytorch(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
......
......@@ -47,7 +47,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
}
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
def gelu(x):
return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
......@@ -327,7 +327,7 @@ class Transformer(nn.Module):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class DistilBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
......
......@@ -42,7 +42,7 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
}
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
def gelu(x):
""" Gaussian Error Linear Unit.
Original Implementation of the gelu activation function in Google Bert repo when initially created.
......@@ -463,7 +463,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class TFDistilBertPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
......
......@@ -67,7 +67,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
#####################
### PyTorch => TF 2.0
# PyTorch => TF 2.0 #
#####################
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
......@@ -197,7 +198,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
#####################
### TF 2.0 => PyTorch
# TF 2.0 => PyTorch #
#####################
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
......
......@@ -79,23 +79,23 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
def call(self, inp, training=False):
if self.pre_lnorm:
##### layer normalization + positionwise feed-forward
# layer normalization + positionwise feed-forward
core_out = self.layer_norm(inp)
core_out = self.layer_1(core_out)
core_out = self.drop_1(core_out, training=training)
core_out = self.layer_2(core_out)
core_out = self.drop_2(core_out, training=training)
##### residual connection
# residual connection
output = core_out + inp
else:
##### positionwise feed-forward
# positionwise feed-forward
core_out = self.layer_1(inp)
core_out = self.drop_1(core_out, training=training)
core_out = self.layer_2(core_out)
core_out = self.drop_2(core_out, training=training)
##### residual connection + layer normalization
# residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
......@@ -206,7 +206,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head)) # qlen x n_head x d_head
#### compute attention score
# compute attention score
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k) # qlen x klen x bsz x n_head
......@@ -218,7 +218,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
attn_score = AC + BD
attn_score = attn_score * self.scale
#### compute attention probability
# compute attention probability
if attn_mask is not None:
attn_mask_t = attn_mask[:, :, None, None]
attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
......@@ -231,22 +231,22 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector
# compute attention vector
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v)
# [qlen x bsz x n_head x d_head]
attn_vec_sizes = shape_list(attn_vec)
attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))
##### linear projection
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out, training=training)
if self.pre_lnorm:
##### residual connection
# residual connection
outputs = [w + attn_out]
else:
##### residual connection + layer normalization
# residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:
......
......@@ -190,7 +190,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask) = inputs
if g is not None:
###### Two-stream attention with relative positional encoding.
# Two-stream attention with relative positional encoding.
# content based attention score
if mems is not None and len(shape_list(mems)) > 1:
cat = tf.concat([mems, h], axis=0)
......@@ -206,7 +206,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# position-based key head
k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
##### h-stream
# h-stream
# content-stream query head
q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
......@@ -221,7 +221,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# post processing
output_h = self.post_attention([h, attn_vec_h], training=training)
##### g-stream
# g-stream
# query-stream query head
q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)
......@@ -251,7 +251,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
attn_prob = attn_prob_h, attn_prob_g
else:
###### Multi-head attention with relative positional encoding
# Multi-head attention with relative positional encoding
if mems is not None and len(shape_list(mems)) > 1:
cat = tf.concat([mems, h], axis=0)
else:
......@@ -552,7 +552,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
##### Attention mask
# Attention mask
# causal attention mask
if self.attn_type == "uni":
attn_mask = self.create_mask(qlen, mlen)
......@@ -597,7 +597,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else:
non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states
# Word embeddings and prepare h & g hidden states
if inputs_embeds is not None:
word_emb_k = inputs_embeds
else:
......@@ -612,7 +612,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else:
output_g = None
##### Segment embedding
# Segment embedding
if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
......@@ -624,7 +624,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else:
seg_mat = None
##### Positional encoding
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
pos_emb = self.dropout(pos_emb, training=training)
......
......@@ -213,16 +213,16 @@ class PositionwiseFF(nn.Module):
def forward(self, inp):
if self.pre_lnorm:
##### layer normalization + positionwise feed-forward
# layer normalization + positionwise feed-forward
core_out = self.CoreNet(self.layer_norm(inp))
##### residual connection
# residual connection
output = core_out + inp
else:
##### positionwise feed-forward
# positionwise feed-forward
core_out = self.CoreNet(inp)
##### residual connection + layer normalization
# residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
......@@ -316,7 +316,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head
#### compute attention score
# compute attention score
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
......@@ -328,7 +328,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
attn_score = AC + BD
attn_score.mul_(self.scale)
#### compute attention probability
# compute attention probability
if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = attn_mask == 1 # Switch to bool
if attn_mask.dim() == 2:
......@@ -352,21 +352,21 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector
# compute attention vector
attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
# residual connection
outputs = [w + attn_out]
else:
##### residual connection + layer normalization
# residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:
......
......@@ -330,7 +330,7 @@ class XLNetRelativeAttention(nn.Module):
def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None):
if g is not None:
###### Two-stream attention with relative positional encoding.
# Two-stream attention with relative positional encoding.
# content based attention score
if mems is not None and mems.dim() > 1:
cat = torch.cat([mems, h], dim=0)
......@@ -346,7 +346,7 @@ class XLNetRelativeAttention(nn.Module):
# position-based key head
k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
##### h-stream
# h-stream
# content-stream query head
q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
......@@ -361,7 +361,7 @@ class XLNetRelativeAttention(nn.Module):
# post processing
output_h = self.post_attention(h, attn_vec_h)
##### g-stream
# g-stream
# query-stream query head
q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
......@@ -391,7 +391,7 @@ class XLNetRelativeAttention(nn.Module):
attn_prob = attn_prob_h, attn_prob_g
else:
###### Multi-head attention with relative positional encoding
# Multi-head attention with relative positional encoding
if mems is not None and mems.dim() > 1:
cat = torch.cat([mems, h], dim=0)
else:
......@@ -804,7 +804,7 @@ class XLNetModel(XLNetPreTrainedModel):
dtype_float = next(self.parameters()).dtype
device = next(self.parameters()).device
##### Attention mask
# Attention mask
# causal attention mask
if self.attn_type == "uni":
attn_mask = self.create_mask(qlen, mlen)
......@@ -849,7 +849,7 @@ class XLNetModel(XLNetPreTrainedModel):
else:
non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states
# Word embeddings and prepare h & g hidden states
if inputs_embeds is not None:
word_emb_k = inputs_embeds
else:
......@@ -864,7 +864,7 @@ class XLNetModel(XLNetPreTrainedModel):
else:
output_g = None
##### Segment embedding
# Segment embedding
if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0:
......@@ -879,7 +879,7 @@ class XLNetModel(XLNetPreTrainedModel):
else:
seg_mat = None
##### Positional encoding
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb)
......
......@@ -178,7 +178,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
return True
## Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
# Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
class GradientAccumulator(object):
"""Distribution strategies-aware gradient accumulation utility."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment