Unverified Commit 7732d0fe authored by Lysandre Debut, committed by GitHub

Upgrade black to version ~=22.0 (#15565)

* Upgrade black to version ~=22.0

* Check copies

* Fix code
parent d923f762
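
For context on the formatting rule behind this commit: black's 22.x stable style hugs the power operator when both operands are "simple" (a name, a numeric literal, or an attribute chain, optionally negated) and keeps the surrounding spaces when either operand is a compound expression. That is the mechanical rewrite applied throughout the diff below. A minimal before/after sketch, using hypothetical variable names:

    # formatted by black < 22.0
    scale = head_dim ** -0.5
    fc_std = (2 * hidden_size) ** -0.5 * factor

    # formatted by black ~= 22.0: simple operands hug **, while the
    # parenthesized (compound) left operand keeps its spaces
    scale = head_dim**-0.5
    fc_std = (2 * hidden_size) ** -0.5 * factor
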
@@ -297,7 +297,7 @@ class BigBirdEmbeddings(nn.Module):
            inputs_embeds = self.word_embeddings(input_ids)
        if self.rescale_embeddings:
-            inputs_embeds = inputs_embeds * (self.hidden_size ** 0.5)
+            inputs_embeds = inputs_embeds * (self.hidden_size**0.5)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
...
@@ -220,7 +220,7 @@ class FlaxBigBirdEmbeddings(nn.Module):
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
        if self.config.rescale_embeddings:
-            inputs_embeds *= self.config.hidden_size ** 0.5
+            inputs_embeds *= self.config.hidden_size**0.5
        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds
...
@@ -1219,7 +1219,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -148,7 +148,7 @@ class BlenderbotAttention(nn.Module):
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -155,7 +155,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
...
@@ -146,7 +146,7 @@ class BlenderbotSmallAttention(nn.Module):
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -154,7 +154,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
...
@@ -96,7 +96,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
        self._extra_ids = extra_ids
-        self._utf_vocab_size = 2 ** 8  # utf is 8 bits
+        self._utf_vocab_size = 2**8  # utf is 8 bits
        # define special tokens dict
        self.special_tokens_encoder: Dict[int, str] = {
...
@@ -177,7 +177,7 @@ class CLIPAttention(nn.Module):
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
-        self.scale = self.head_dim ** -0.5
+        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
@@ -348,13 +348,13 @@ class CLIPPreTrainedModel(PreTrainedModel):
            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, CLIPVisionEmbeddings):
            factor = self.config.initializer_factor
-            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim ** -0.5 * factor)
+            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
            nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
        elif isinstance(module, CLIPAttention):
            factor = self.config.initializer_factor
-            in_proj_std = (module.embed_dim ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
-            out_proj_std = (module.embed_dim ** -0.5) * factor
+            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
@@ -362,7 +362,7 @@ class CLIPPreTrainedModel(PreTrainedModel):
        elif isinstance(module, CLIPMLP):
            factor = self.config.initializer_factor
            in_proj_std = (
-                (module.config.hidden_size ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+                (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            )
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
@@ -370,11 +370,11 @@ class CLIPPreTrainedModel(PreTrainedModel):
        elif isinstance(module, CLIPModel):
            nn.init.normal_(
                module.text_projection.weight,
-                std=module.text_embed_dim ** -0.5 * self.config.initializer_factor,
+                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
            )
            nn.init.normal_(
                module.visual_projection.weight,
-                std=module.vision_embed_dim ** -0.5 * self.config.initializer_factor,
+                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
            )
        if isinstance(module, nn.LayerNorm):
...
@@ -263,7 +263,7 @@ class FlaxCLIPAttention(nn.Module):
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
-        self.scale = self.head_dim ** -0.5
+        self.scale = self.head_dim**-0.5
        self.dropout = self.config.attention_dropout
        self.k_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
...
@@ -156,7 +156,7 @@ class TFCLIPVisionEmbeddings(tf.keras.layers.Layer):
        self.class_embedding = self.add_weight(
            shape=(self.embed_dim,),
-            initializer=get_initializer(self.embed_dim ** -0.5 * factor),
+            initializer=get_initializer(self.embed_dim**-0.5 * factor),
            trainable=True,
            name="class_embedding",
        )
@@ -270,8 +270,8 @@ class TFCLIPAttention(tf.keras.layers.Layer):
            )
        factor = config.initializer_factor
-        in_proj_std = (self.embed_dim ** -0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor
-        out_proj_std = (self.embed_dim ** -0.5) * factor
+        in_proj_std = (self.embed_dim**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor
+        out_proj_std = (self.embed_dim**-0.5) * factor
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
@@ -360,7 +360,7 @@ class TFCLIPMLP(tf.keras.layers.Layer):
        self.activation_fn = get_tf_activation(config.hidden_act)
        factor = config.initializer_factor
-        in_proj_std = (config.hidden_size ** -0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor
+        in_proj_std = (config.hidden_size**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor
        fc_std = (2 * config.hidden_size) ** -0.5 * factor
        self.fc1 = tf.keras.layers.Dense(
@@ -753,14 +753,14 @@ class TFCLIPMainLayer(tf.keras.layers.Layer):
        self.visual_projection = tf.keras.layers.Dense(
            units=self.projection_dim,
-            kernel_initializer=get_initializer(vision_config.hidden_size ** -0.5 * self.config.initializer_factor),
+            kernel_initializer=get_initializer(vision_config.hidden_size**-0.5 * self.config.initializer_factor),
            use_bias=False,
            name="visual_projection",
        )
        self.text_projection = tf.keras.layers.Dense(
            units=self.projection_dim,
-            kernel_initializer=get_initializer(text_config.hidden_size ** -0.5 * self.config.initializer_factor),
+            kernel_initializer=get_initializer(text_config.hidden_size**-0.5 * self.config.initializer_factor),
            use_bias=False,
            name="text_projection",
        )
...
@@ -68,10 +68,10 @@ def bytes_to_unicode():
    )
    cs = bs[:]
    n = 0
-    for b in range(2 ** 8):
+    for b in range(2**8):
        if b not in bs:
            bs.append(b)
-            cs.append(2 ** 8 + n)
+            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
...
@@ -488,7 +488,7 @@ class DetrAttention(nn.Module):
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -823,7 +823,7 @@ class Attention(nn.Module):
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
        self.encoder_decoder_attention = encoder_decoder_attention
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -278,7 +278,7 @@ class FunnelAttentionStructure(nn.Module):
            # Second type
            pos = pooled_pos
-            stride = 2 ** block_index
+            stride = 2**block_index
            rel_pos = self.relative_pos(pos, stride)
            rel_pos = rel_pos[:, None] + zero_offset
@@ -297,7 +297,7 @@ class FunnelAttentionStructure(nn.Module):
                # the previous block of the 1st real block. Since the 1st real
                # block always has position 1, the position of the previous block
                # will be at `1 - 2 ** block_index`.
-                cls_pos = pos_id.new_tensor([-(2 ** block_index) + 1])
+                cls_pos = pos_id.new_tensor([-(2**block_index) + 1])
                pooled_pos_id = pos_id[1:-1] if self.config.truncate_seq else pos_id[1:]
                return torch.cat([cls_pos, pooled_pos_id[::2]], 0)
            else:
@@ -454,7 +454,7 @@ class FunnelRelMultiheadAttention(nn.Module):
        self.post_proj = nn.Linear(n_head * d_head, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=config.layer_norm_eps)
-        self.scale = 1.0 / (d_head ** 0.5)
+        self.scale = 1.0 / (d_head**0.5)
    def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
        """Relative attention score for the positional encodings"""
...
@@ -231,7 +231,7 @@ class TFFunnelAttentionStructure:
            # Second type
            pos = pooled_pos
-            stride = 2 ** block_index
+            stride = 2**block_index
            rel_pos = self.relative_pos(pos, stride)
            # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
@@ -252,7 +252,7 @@ class TFFunnelAttentionStructure:
                # the previous block of the 1st real block. Since the 1st real
                # block always has position 1, the position of the previous block
                # will be at `1 - 2 ** block_index`.
-                cls_pos = tf.constant([-(2 ** block_index) + 1], dtype=pos_id.dtype)
+                cls_pos = tf.constant([-(2**block_index) + 1], dtype=pos_id.dtype)
                pooled_pos_id = pos_id[1:-1] if self.truncate_seq else pos_id[1:]
                return tf.concat([cls_pos, pooled_pos_id[::2]], 0)
            else:
@@ -400,7 +400,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
        self.post_proj = tf.keras.layers.Dense(d_model, kernel_initializer=initializer, name="post_proj")
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
-        self.scale = 1.0 / (d_head ** 0.5)
+        self.scale = 1.0 / (d_head**0.5)
    def build(self, input_shape):
        n_head, d_head, d_model = self.n_head, self.d_head, self.d_model
...
@@ -78,10 +78,10 @@ def bytes_to_unicode():
    )
    cs = bs[:]
    n = 0
-    for b in range(2 ** 8):
+    for b in range(2**8):
        if b not in bs:
            bs.append(b)
-            cs.append(2 ** 8 + n)
+            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
...
@@ -418,7 +418,7 @@ class HubertAttention(nn.Module):
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...
@@ -741,7 +741,7 @@ class TFHubertAttention(tf.keras.layers.Layer):
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
...
@@ -327,16 +327,16 @@ class IntGELU(nn.Module):
    def int_erf(self, x_int, scaling_factor):
        b_int = torch.floor(self.coeff[1] / scaling_factor)
-        c_int = torch.floor(self.coeff[2] / scaling_factor ** 2)
+        c_int = torch.floor(self.coeff[2] / scaling_factor**2)
        sign = torch.sign(x_int)
        abs_int = torch.min(torch.abs(x_int), -b_int)
        y_int = sign * ((abs_int + b_int) ** 2 + c_int)
-        scaling_factor = scaling_factor ** 2 * self.coeff[0]
+        scaling_factor = scaling_factor**2 * self.coeff[0]
        # avoid overflow
-        y_int = floor_ste.apply(y_int / 2 ** self.const)
-        scaling_factor = scaling_factor * 2 ** self.const
+        y_int = floor_ste.apply(y_int / 2**self.const)
+        scaling_factor = scaling_factor * 2**self.const
        return y_int, scaling_factor
@@ -388,9 +388,9 @@ class IntSoftmax(nn.Module):
    def int_polynomial(self, x_int, scaling_factor):
        with torch.no_grad():
            b_int = torch.floor(self.coef[1] / scaling_factor)
-            c_int = torch.floor(self.coef[2] / scaling_factor ** 2)
+            c_int = torch.floor(self.coef[2] / scaling_factor**2)
            z = (x_int + b_int) * x_int + c_int
-            scaling_factor = self.coef[0] * scaling_factor ** 2
+            scaling_factor = self.coef[0] * scaling_factor**2
        return z, scaling_factor
    def int_exp(self, x_int, scaling_factor):
@@ -402,7 +402,7 @@ class IntSoftmax(nn.Module):
        r = x_int - x0_int * q
        exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
        exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
-        scaling_factor = exp_scaling_factor / 2 ** self.const
+        scaling_factor = exp_scaling_factor / 2**self.const
        return exp_int, scaling_factor
    def forward(self, x, scaling_factor):
@@ -420,9 +420,9 @@ class IntSoftmax(nn.Module):
        exp_int = exp / exp_scaling_factor
        exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
-        factor = floor_ste.apply(2 ** self.max_bit / exp_int_sum)
+        factor = floor_ste.apply(2**self.max_bit / exp_int_sum)
        exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
-        scaling_factor = 1 / 2 ** self.output_bit
+        scaling_factor = 1 / 2**self.output_bit
        return exp_int * scaling_factor, scaling_factor
@@ -460,9 +460,9 @@ class IntLayerNorm(nn.Module):
    def set_shift(self, y_int):
        with torch.no_grad():
-            y_sq_int = y_int ** 2
+            y_sq_int = y_int**2
            var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
-            shift = (torch.log2(torch.sqrt(var_int / 2 ** self.max_bit)).ceil()).max()
+            shift = (torch.log2(torch.sqrt(var_int / 2**self.max_bit)).ceil()).max()
            shift_old = self.shift
            self.shift = torch.max(self.shift, shift)
            logger.info(f"Dynamic shift adjustment: {int(shift_old)} -> {int(self.shift)}")
@@ -473,8 +473,8 @@ class IntLayerNorm(nn.Module):
        to avoid overflow in the subsequent runs.
        """
        self.set_shift(y_int)  # adjusts `self.shift`
-        y_int_shifted = floor_ste.apply(y_int / 2 ** self.shift)
-        y_sq_int = y_int_shifted ** 2
+        y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
+        y_sq_int = y_int_shifted**2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
        return var_int
@@ -482,7 +482,7 @@ class IntLayerNorm(nn.Module):
        if not self.quant_mode:
            mean = x.mean(axis=2, keepdim=True)
            y = x - mean
-            var = torch.mean(y ** 2, axis=2, keepdim=True)
+            var = torch.mean(y**2, axis=2, keepdim=True)
            x = y / torch.sqrt(self.eps + var)
            x = x * self.weight + self.bias
            return x, None
@@ -496,25 +496,25 @@ class IntLayerNorm(nn.Module):
        x_int = x / scaling_factor
        mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True))
        y_int = x_int - mean_int
-        y_int_shifted = floor_ste.apply(y_int / 2 ** self.shift)
-        y_sq_int = y_int_shifted ** 2
+        y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
+        y_sq_int = y_int_shifted**2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
        # overflow handling in training time
        if self.training:
            # if overflow is detected
-            if var_int.max() >= 2 ** self.max_bit:
+            if var_int.max() >= 2**self.max_bit:
                var_int = self.overflow_fallback(y_int)
-                assert var_int.max() < 2 ** self.max_bit + 0.1, (
+                assert var_int.max() < 2**self.max_bit + 0.1, (
                    "Error detected in overflow handling: "
                    "`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
                )
        # To be replaced with integer-sqrt kernel that produces the same output
-        std_int = floor_ste.apply(torch.sqrt(var_int)) * 2 ** self.shift
-        factor = floor_ste.apply(2 ** 31 / std_int)
+        std_int = floor_ste.apply(torch.sqrt(var_int)) * 2**self.shift
+        factor = floor_ste.apply(2**31 / std_int)
        y_int = floor_ste.apply(y_int * factor / 2)
-        scaling_factor = self.dim_sqrt / 2 ** 30
+        scaling_factor = self.dim_sqrt / 2**30
        # scaling and shifting
        bias = self.bias.data.detach() / (self.weight.data.detach())
@@ -725,7 +725,7 @@ def batch_frexp(inputs, max_bit=31):
    tmp_m = []
    for m in output_m:
        int_m_shifted = int(
-            decimal.Decimal(m * (2 ** max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP)
+            decimal.Decimal(m * (2**max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP)
        )
        tmp_m.append(int_m_shifted)
    output_m = np.array(tmp_m)
@@ -796,7 +796,7 @@ class FixedPointMul(Function):
        m, e = batch_frexp(new_scale)
        output = z_int.type(torch.double) * m.type(torch.double)
-        output = torch.round(output / (2.0 ** e))
+        output = torch.round(output / (2.0**e))
        if identity is not None:
            # needs addition of identity activation
@@ -809,7 +809,7 @@ class FixedPointMul(Function):
        m1, e1 = batch_frexp(new_scale)
        output1 = wx_int.type(torch.double) * m1.type(torch.double)
-        output1 = torch.round(output1 / (2.0 ** e1))
+        output1 = torch.round(output1 / (2.0**e1))
        output = output1 + output
...