Unverified commit 7732d0fe authored by Lysandre Debut, committed by GitHub

Upgrade black to version ~=22.0 (#15565)

* Upgrade black to version ~=22.0

* Check copies

* Fix code
parent d923f762
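
Note on the formatting change (this note is not part of the commit itself): black 22.x began "hugging" the power operator, dropping the spaces around `**` when both operands are simple (names, attribute accesses, or numeric literals, optionally negated), while keeping the spaces when either operand is a more complex expression. That is why `(d_model ** -0.5)` becomes `(d_model**-0.5)` throughout the diff below, but lines such as `((n_heads * key_value_proj_dim) ** -0.5)` are left unchanged. A minimal, hypothetical before/after sketch (the variable names merely mirror those in the diff):

d_head, d_model, key_value_proj_dim, factor = 64, 512, 64, 1.0

# black 21.x kept spaces around ** in every case:
scale = 1 / (d_head ** 0.5)
std = factor * (d_model ** -0.5)

# black ~=22.0 removes the spaces when both operands are simple ...
scale = 1 / (d_head**0.5)
std = factor * (d_model**-0.5)

# ... but keeps them when an operand is itself a compound expression,
# which is why lines like this one are untouched in the diff:
std_o = factor * ((d_model * key_value_proj_dim) ** -0.5)

The version pin itself (presumably something like `black~=22.0` in the repository's dev dependencies) is not visible in the hunks shown here.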
@@ -544,8 +544,8 @@ class SwinEncoder(nn.Module):
             [
                 SwinLayer(
                     config=config,
-                    dim=int(config.embed_dim * 2 ** i_layer),
-                    input_resolution=(grid_size[0] // (2 ** i_layer), grid_size[1] // (2 ** i_layer)),
+                    dim=int(config.embed_dim * 2**i_layer),
+                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                     depth=config.depths[i_layer],
                     num_heads=config.num_heads[i_layer],
                     drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
@@ -92,8 +92,8 @@ class FlaxT5DenseReluDense(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
-        wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
-        wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
+        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
+        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
 
         self.wi = nn.Dense(
             self.config.d_ff,
@@ -122,8 +122,8 @@ class FlaxT5DenseGatedGeluDense(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
-        wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
-        wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
+        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
+        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
 
         self.wi_0 = nn.Dense(
             self.config.d_ff,
@@ -194,8 +194,8 @@ class FlaxT5Attention(nn.Module):
         self.inner_dim = self.n_heads * self.key_value_proj_dim
 
         q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
-        kv_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
-        o_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
+        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
 
         self.q = nn.Dense(
             self.inner_dim,
@@ -1434,7 +1434,7 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
         if self.config.tie_word_embeddings:
             # Rescale output before projecting on vocab
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
-            sequence_output = sequence_output * (self.model_dim ** -0.5)
+            sequence_output = sequence_output * (self.model_dim**-0.5)
 
         if self.config.tie_word_embeddings:
             shared_embedding = self.shared.variables["params"]["embedding"]
@@ -1542,7 +1542,7 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
             if self.config.tie_word_embeddings:
                 # Rescale output before projecting on vocab
                 # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
-                sequence_output = sequence_output * (self.config.d_model ** -0.5)
+                sequence_output = sequence_output * (self.config.d_model**-0.5)
 
             if self.config.tie_word_embeddings:
                 shared_embedding = module.shared.variables["params"]["embedding"]
@@ -771,8 +771,8 @@ class T5PreTrainedModel(PreTrainedModel):
             key_value_proj_dim = self.config.d_kv
             n_heads = self.config.num_heads
             module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
-            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
-            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
+            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
+            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
             module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
             if module.has_relative_attention_bias:
                 module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
@@ -1639,7 +1639,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         if self.config.tie_word_embeddings:
             # Rescale output before projecting on vocab
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
-            sequence_output = sequence_output * (self.model_dim ** -0.5)
+            sequence_output = sequence_output * (self.model_dim**-0.5)
 
         lm_logits = self.lm_head(sequence_output)
@@ -94,10 +94,10 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         wi_initializer = tf.keras.initializers.RandomNormal(
-            mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
+            mean=0, stddev=config.initializer_factor * (config.d_model**-0.5)
         )
         wo_initializer = tf.keras.initializers.RandomNormal(
-            mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
+            mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5)
         )
         self.wi = tf.keras.layers.Dense(
             config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer
@@ -120,10 +120,10 @@ class TFT5GatedGeluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         wi_initializer = tf.keras.initializers.RandomNormal(
-            mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
+            mean=0, stddev=config.initializer_factor * (config.d_model**-0.5)
         )
         wo_initializer = tf.keras.initializers.RandomNormal(
-            mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
+            mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5)
         )
         self.wi_0 = tf.keras.layers.Dense(
             config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer
@@ -189,16 +189,16 @@ class TFT5Attention(tf.keras.layers.Layer):
             mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
         )
         k_initializer = tf.keras.initializers.RandomNormal(
-            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+            mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5)
         )
         v_initializer = tf.keras.initializers.RandomNormal(
-            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+            mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5)
         )
         o_initializer = tf.keras.initializers.RandomNormal(
-            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+            mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5)
         )
         self.relative_attention_bias_initializer = tf.keras.initializers.RandomNormal(
-            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+            mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5)
         )
 
         self.q = tf.keras.layers.Dense(
@@ -1472,7 +1472,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         # T5v1.1 does not tie output word embeddings and thus does not require downscaling
         if self.config.tie_word_embeddings:
-            sequence_output = sequence_output * (self.model_dim ** -0.5)
+            sequence_output = sequence_output * (self.model_dim**-0.5)
             logits = self.shared(sequence_output, mode="linear")
         else:
             logits = self.lm_head(sequence_output)
@@ -2365,7 +2365,7 @@ def _calculate_expected_result(
 # PyTorch does not currently support Huber loss with custom delta so we define it ourself
 def huber_loss(input, target, delta: float = 1.0):
     errors = torch.abs(input - target)  # shape (batch_size,)
-    return torch.where(errors < delta, 0.5 * errors ** 2, errors * delta - (0.5 * delta ** 2))
+    return torch.where(errors < delta, 0.5 * errors**2, errors * delta - (0.5 * delta**2))
 
 
 def _calculate_regression_loss(
@@ -149,7 +149,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layer_norm")
 
-        self.scale = 1 / (d_head ** 0.5)
+        self.scale = 1 / (d_head**0.5)
 
         self.pre_lnorm = pre_lnorm
@@ -350,7 +350,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
         self.div_val = div_val
         self.d_proj = d_proj
 
-        self.emb_scale = d_proj ** 0.5
+        self.emb_scale = d_proj**0.5
 
         self.cutoff_ends = [0] + self.cutoffs
@@ -362,7 +362,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
         else:
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
-                d_emb_i = d_embed // (div_val ** i)
+                d_emb_i = d_embed // (div_val**i)
                 self.emb_layers.append(
                     TFTransfoEmbeddings(
                         r_idx - l_idx,
@@ -374,7 +374,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
     def build(self, input_shape):
         for i in range(len(self.cutoffs)):
-            d_emb_i = self.d_embed // (self.div_val ** i)
+            d_emb_i = self.d_embed // (self.div_val**i)
             self.emb_projs.append(
                 self.add_weight(
                     shape=(d_emb_i, self.d_proj),
@@ -80,7 +80,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
         else:
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
-                d_emb_i = self.d_embed // (self.div_val ** i)
+                d_emb_i = self.d_embed // (self.div_val**i)
                 weight = self.add_weight(
                     shape=(d_emb_i, self.d_proj), initializer="zeros", trainable=True, name=f"out_projs_._{i}"
@@ -259,7 +259,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
 
-        self.scale = 1 / (d_head ** 0.5)
+        self.scale = 1 / (d_head**0.5)
 
         self.pre_lnorm = pre_lnorm
@@ -412,7 +412,7 @@ class AdaptiveEmbedding(nn.Module):
         self.div_val = div_val
         self.d_proj = d_proj
 
-        self.emb_scale = d_proj ** 0.5
+        self.emb_scale = d_proj**0.5
 
         self.cutoff_ends = [0] + self.cutoffs
@@ -425,7 +425,7 @@ class AdaptiveEmbedding(nn.Module):
         else:
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
-                d_emb_i = d_embed // (div_val ** i)
+                d_emb_i = d_embed // (div_val**i)
                 self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                 self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
@@ -60,7 +60,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         else:
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
-                d_emb_i = d_embed // (div_val ** i)
+                d_emb_i = d_embed // (div_val**i)
                 self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
@@ -185,7 +185,7 @@ class TrOCRAttention(nn.Module):
             raise ValueError(
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
 
         self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
@@ -484,7 +484,7 @@ class UniSpeechAttention(nn.Module):
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                 f" and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
 
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -523,7 +523,7 @@ class UniSpeechSatAttention(nn.Module):
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                 f" and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
 
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -192,7 +192,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     omega = np.arange(embed_dim // 2, dtype=np.float)
     omega /= embed_dim / 2.0
-    omega = 1.0 / 10000 ** omega  # (D/2,)
+    omega = 1.0 / 10000**omega  # (D/2,)
 
     pos = pos.reshape(-1)  # (M,)
     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
@@ -231,7 +231,7 @@ class ViTMAEEmbeddings(nn.Module):
     def initialize_weights(self):
         # initialize (and freeze) position embeddings by sin-cos embedding
         pos_embed = get_2d_sincos_pos_embed(
-            self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches ** 0.5), add_cls_token=True
+            self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
         )
         self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
@@ -741,7 +741,7 @@ class ViTMAEDecoder(nn.Module):
         self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size)
         self.decoder_pred = nn.Linear(
-            config.decoder_hidden_size, config.patch_size ** 2 * config.num_channels, bias=True
+            config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
         )  # encoder to decoder
         self.gradient_checkpointing = False
         self.config = config
@@ -750,7 +750,7 @@ class ViTMAEDecoder(nn.Module):
     def initialize_weights(self, num_patches):
         # initialize (and freeze) position embeddings by sin-cos embedding
         decoder_pos_embed = get_2d_sincos_pos_embed(
-            self.decoder_pos_embed.shape[-1], int(num_patches ** 0.5), add_cls_token=True
+            self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
         )
         self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
@@ -861,7 +861,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
         h = w = imgs.shape[2] // p
         x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
         x = torch.einsum("nchpwq->nhwpqc", x)
-        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
+        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
         return x
 
     def unpatchify(self, x):
@@ -770,7 +770,7 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                 f" and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
 
         self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
@@ -566,7 +566,7 @@ class Wav2Vec2Attention(nn.Module):
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                 f" and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
 
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -486,7 +486,7 @@ class WavLMAttention(nn.Module):
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                 f" and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
 
         self.k_proj = nn.Linear(embed_dim, embed_dim)
         self.v_proj = nn.Linear(embed_dim, embed_dim)
@@ -261,7 +261,7 @@ class XGLMAttention(nn.Module):
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                 f" and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
 
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -169,7 +169,7 @@ class XLMConfig(PretrainedConfig):
         n_langs=1,
         use_lang_emb=True,
         max_position_embeddings=512,
-        embed_init_std=2048 ** -0.5,
+        embed_init_std=2048**-0.5,
         layer_norm_eps=1e-12,
         init_std=0.02,
         bos_index=0,
@@ -76,7 +76,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
         self.n_head = config.n_head
         self.d_head = config.d_head
         self.d_model = config.d_model
-        self.scale = 1 / (config.d_head ** 0.5)
+        self.scale = 1 / (config.d_head**0.5)
         self.initializer_range = config.initializer_range
         self.output_attentions = config.output_attentions
@@ -220,7 +220,7 @@ class XLNetRelativeAttention(nn.Module):
         self.n_head = config.n_head
         self.d_head = config.d_head
         self.d_model = config.d_model
-        self.scale = 1 / (config.d_head ** 0.5)
+        self.scale = 1 / (config.d_head**0.5)
 
         self.q = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
         self.k = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))