Unverified Commit beb2a096 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

DeepSpeed: hardcode `torch.arange` dtype on `float` usage to avoid incorrect...

DeepSpeed: hardcode `torch.arange` dtype on `float` usage to avoid incorrect initialization (#28760)
parent f7076cd3
...@@ -255,7 +255,7 @@ class ClvpRotaryPositionalEmbedding(nn.Module): ...@@ -255,7 +255,7 @@ class ClvpRotaryPositionalEmbedding(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
dim = max(config.projection_dim // (config.num_attention_heads * 2), 32) dim = max(config.projection_dim // (config.num_attention_heads * 2), 32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
self.register_buffer("inv_freq", inv_freq) self.register_buffer("inv_freq", inv_freq)
self.cached_sequence_length = None self.cached_sequence_length = None
......
...@@ -53,8 +53,8 @@ CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -53,8 +53,8 @@ CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions # Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float() sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1) return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
......
...@@ -443,7 +443,7 @@ class ConditionalDetrSinePositionEmbedding(nn.Module): ...@@ -443,7 +443,7 @@ class ConditionalDetrSinePositionEmbedding(nn.Module):
y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
pos_x = x_embed[:, :, :, None] / dim_t pos_x = x_embed[:, :, :, None] / dim_t
......
...@@ -47,8 +47,8 @@ def angle_defn(pos, i, d_model_size): ...@@ -47,8 +47,8 @@ def angle_defn(pos, i, d_model_size):
def positional_encoding(position, d_model_size, dtype): def positional_encoding(position, d_model_size, dtype):
# create the sinusoidal pattern for the positional encoding # create the sinusoidal pattern for the positional encoding
angle_rads = angle_defn( angle_rads = angle_defn(
torch.arange(position, dtype=dtype).unsqueeze(1), torch.arange(position, dtype=torch.int64).to(dtype).unsqueeze(1),
torch.arange(d_model_size, dtype=dtype).unsqueeze(0), torch.arange(d_model_size, dtype=torch.int64).to(dtype).unsqueeze(0),
d_model_size, d_model_size,
) )
......
...@@ -491,7 +491,7 @@ class DeformableDetrSinePositionEmbedding(nn.Module): ...@@ -491,7 +491,7 @@ class DeformableDetrSinePositionEmbedding(nn.Module):
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
pos_x = x_embed[:, :, :, None] / dim_t pos_x = x_embed[:, :, :, None] / dim_t
...@@ -617,7 +617,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module): ...@@ -617,7 +617,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
def _reset_parameters(self): def _reset_parameters(self):
nn.init.constant_(self.sampling_offsets.weight.data, 0.0) nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = ( grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0]) (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
...@@ -1557,7 +1557,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): ...@@ -1557,7 +1557,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
temperature = 10000 temperature = 10000
scale = 2 * math.pi scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float()
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
# batch_size, num_queries, 4 # batch_size, num_queries, 4
proposals = proposals.sigmoid() * scale proposals = proposals.sigmoid() * scale
......
...@@ -71,7 +71,7 @@ class OpenLlamaRotaryEmbedding(nn.Module): ...@@ -71,7 +71,7 @@ class OpenLlamaRotaryEmbedding(nn.Module):
self.dim = dim self.dim = dim
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.base = base self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work. # Build here to make `torch.jit.trace` work.
...@@ -81,7 +81,7 @@ class OpenLlamaRotaryEmbedding(nn.Module): ...@@ -81,7 +81,7 @@ class OpenLlamaRotaryEmbedding(nn.Module):
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation # Different from paper, but it uses a different permutation in order to obtain the same calculation
...@@ -110,7 +110,7 @@ class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): ...@@ -110,7 +110,7 @@ class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
...@@ -135,10 +135,10 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): ...@@ -135,10 +135,10 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
base = self.base * ( base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2)) ) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation # Different from paper, but it uses a different permutation in order to obtain the same calculation
......
...@@ -942,7 +942,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -942,7 +942,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
hids = [] hids = []
attentions = [] if output_attentions else None attentions = [] if output_attentions else None
if self.attn_type == 0: # default if self.attn_type == 0: # default
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype) pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=torch.int64).type_as(
dtype=word_emb.dtype
)
if self.clamp_len > 0: if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len) pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq) pos_emb = self.pos_emb(pos_seq)
......
...@@ -401,7 +401,7 @@ class DetaSinePositionEmbedding(nn.Module): ...@@ -401,7 +401,7 @@ class DetaSinePositionEmbedding(nn.Module):
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
pos_x = x_embed[:, :, :, None] / dim_t pos_x = x_embed[:, :, :, None] / dim_t
...@@ -526,7 +526,7 @@ class DetaMultiscaleDeformableAttention(nn.Module): ...@@ -526,7 +526,7 @@ class DetaMultiscaleDeformableAttention(nn.Module):
def _reset_parameters(self): def _reset_parameters(self):
nn.init.constant_(self.sampling_offsets.weight.data, 0.0) nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = ( grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0]) (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
...@@ -1447,7 +1447,7 @@ class DetaModel(DetaPreTrainedModel): ...@@ -1447,7 +1447,7 @@ class DetaModel(DetaPreTrainedModel):
temperature = 10000 temperature = 10000
scale = 2 * math.pi scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float()
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
# batch_size, num_queries, 4 # batch_size, num_queries, 4
proposals = proposals.sigmoid() * scale proposals = proposals.sigmoid() * scale
......
...@@ -435,7 +435,7 @@ class DetrSinePositionEmbedding(nn.Module): ...@@ -435,7 +435,7 @@ class DetrSinePositionEmbedding(nn.Module):
y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
pos_x = x_embed[:, :, :, None] / dim_t pos_x = x_embed[:, :, :, None] / dim_t
......
...@@ -94,7 +94,7 @@ class RotaryEmbedding(torch.nn.Module): ...@@ -94,7 +94,7 @@ class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim: int): def __init__(self, dim: int):
super().__init__() super().__init__()
# Generate and save the inverse frequency buffer (non trainable) # Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
inv_freq = inv_freq inv_freq = inv_freq
self.register_buffer("inv_freq", inv_freq) self.register_buffer("inv_freq", inv_freq)
......
...@@ -138,7 +138,7 @@ class FalconRotaryEmbedding(nn.Module): ...@@ -138,7 +138,7 @@ class FalconRotaryEmbedding(nn.Module):
self.dim = dim self.dim = dim
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.base = base self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work. # Build here to make `torch.jit.trace` work.
...@@ -148,7 +148,7 @@ class FalconRotaryEmbedding(nn.Module): ...@@ -148,7 +148,7 @@ class FalconRotaryEmbedding(nn.Module):
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation # Different from paper, but it uses a different permutation in order to obtain the same calculation
...@@ -177,7 +177,7 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): ...@@ -177,7 +177,7 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
...@@ -202,10 +202,10 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): ...@@ -202,10 +202,10 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
base = self.base * ( base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2)) ) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation # Different from paper, but it uses a different permutation in order to obtain the same calculation
......
...@@ -820,9 +820,9 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module): ...@@ -820,9 +820,9 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
# are to the left (i>j) and negative relative positions otherwise (i<j). # are to the left (i>j) and negative relative positions otherwise (i<j).
pos_enc_positive = torch.zeros(x.size(1), self.embed_dim) pos_enc_positive = torch.zeros(x.size(1), self.embed_dim)
pos_enc_negative = torch.zeros(x.size(1), self.embed_dim) pos_enc_negative = torch.zeros(x.size(1), self.embed_dim)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
div_term = torch.exp( div_term = torch.exp(
torch.arange(0, self.embed_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embed_dim) torch.arange(0, self.embed_dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.embed_dim)
) )
pos_enc_positive[:, 0::2] = torch.sin(position * div_term) pos_enc_positive[:, 0::2] = torch.sin(position * div_term)
pos_enc_positive[:, 1::2] = torch.cos(position * div_term) pos_enc_positive[:, 1::2] = torch.cos(position * div_term)
......
...@@ -1346,8 +1346,8 @@ class SinusoidalPositionalEmbedding(nn.Embedding): ...@@ -1346,8 +1346,8 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
""" """
half_dim = embedding_dim // 2 half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1: if embedding_dim % 2 == 1:
# zero pad # zero pad
......
...@@ -235,8 +235,8 @@ class FunnelAttentionStructure(nn.Module): ...@@ -235,8 +235,8 @@ class FunnelAttentionStructure(nn.Module):
if self.config.attention_type == "factorized": if self.config.attention_type == "factorized":
# Notations from the paper, appending A.2.2, final formula. # Notations from the paper, appending A.2.2, final formula.
# We need to create and return the matrices phi, psi, pi and omega. # We need to create and return the matrices phi, psi, pi and omega.
pos_seq = torch.arange(0, seq_len, 1.0, dtype=dtype, device=device) pos_seq = torch.arange(0, seq_len, 1.0, dtype=torch.int64, device=device).to(dtype)
freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=dtype, device=device) freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2))) inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
sinusoid = pos_seq[:, None] * inv_freq[None] sinusoid = pos_seq[:, None] * inv_freq[None]
sin_embed = torch.sin(sinusoid) sin_embed = torch.sin(sinusoid)
...@@ -252,17 +252,17 @@ class FunnelAttentionStructure(nn.Module): ...@@ -252,17 +252,17 @@ class FunnelAttentionStructure(nn.Module):
else: else:
# Notations from the paper, appending A.2.1, final formula. # Notations from the paper, appending A.2.1, final formula.
# We need to create and return all the possible vectors R for all blocks and shifts. # We need to create and return all the possible vectors R for all blocks and shifts.
freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=dtype, device=device) freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2))) inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
# Maximum relative positions for the first input # Maximum relative positions for the first input
rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype, device=device) rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0, dtype=torch.int64, device=device).to(dtype)
zero_offset = seq_len * 2 zero_offset = seq_len * 2
sinusoid = rel_pos_id[:, None] * inv_freq[None] sinusoid = rel_pos_id[:, None] * inv_freq[None]
sin_embed = self.sin_dropout(torch.sin(sinusoid)) sin_embed = self.sin_dropout(torch.sin(sinusoid))
cos_embed = self.cos_dropout(torch.cos(sinusoid)) cos_embed = self.cos_dropout(torch.cos(sinusoid))
pos_embed = torch.cat([sin_embed, cos_embed], dim=-1) pos_embed = torch.cat([sin_embed, cos_embed], dim=-1)
pos = torch.arange(0, seq_len, dtype=dtype, device=device) pos = torch.arange(0, seq_len, dtype=torch.int64, device=device).to(dtype)
pooled_pos = pos pooled_pos = pos
position_embeds_list = [] position_embeds_list = []
for block_index in range(0, self.config.num_blocks): for block_index in range(0, self.config.num_blocks):
......
...@@ -684,8 +684,8 @@ class FuyuImageProcessor(BaseImageProcessor): ...@@ -684,8 +684,8 @@ class FuyuImageProcessor(BaseImageProcessor):
# Indices of image patches. # Indices of image patches.
patches_mask = subseq_image_input_ids == image_placeholder_id patches_mask = subseq_image_input_ids == image_placeholder_id
num_patches = torch.count_nonzero(patches_mask) num_patches = torch.count_nonzero(patches_mask)
indices = torch.arange( indices = torch.arange(num_patches, dtype=torch.int64, device=subseq_image_input_ids.device).type_as(
num_patches, dtype=subseq_image_input_ids.dtype, device=subseq_image_input_ids.device subseq_image_input_ids
) )
# Place those indices in the image input ids token stream, with -1 representing non-index tokens. # Place those indices in the image input ids token stream, with -1 representing non-index tokens.
......
...@@ -534,7 +534,7 @@ class GPTNeoXRotaryEmbedding(nn.Module): ...@@ -534,7 +534,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
self.dim = dim self.dim = dim
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.base = base self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work. # Build here to make `torch.jit.trace` work.
...@@ -544,7 +544,7 @@ class GPTNeoXRotaryEmbedding(nn.Module): ...@@ -544,7 +544,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation # Different from paper, but it uses a different permutation in order to obtain the same calculation
...@@ -573,7 +573,7 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): ...@@ -573,7 +573,7 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
...@@ -598,10 +598,10 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): ...@@ -598,10 +598,10 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
base = self.base * ( base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2)) ) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation # Different from paper, but it uses a different permutation in order to obtain the same calculation
......
...@@ -242,7 +242,7 @@ class RotaryEmbedding(nn.Module): ...@@ -242,7 +242,7 @@ class RotaryEmbedding(nn.Module):
self.dim = dim self.dim = dim
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.base = base self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work. # Build here to make `torch.jit.trace` work.
...@@ -252,7 +252,7 @@ class RotaryEmbedding(nn.Module): ...@@ -252,7 +252,7 @@ class RotaryEmbedding(nn.Module):
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq) freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation # Different from paper, but it uses a different permutation in order to obtain the same calculation
......
...@@ -56,8 +56,8 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -56,8 +56,8 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float() sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1) return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
......
...@@ -477,7 +477,7 @@ class IdeficsEmbedding(torch.nn.Module): ...@@ -477,7 +477,7 @@ class IdeficsEmbedding(torch.nn.Module):
self.dim = dim self.dim = dim
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.base = base self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work. # Build here to make `torch.jit.trace` work.
...@@ -487,7 +487,7 @@ class IdeficsEmbedding(torch.nn.Module): ...@@ -487,7 +487,7 @@ class IdeficsEmbedding(torch.nn.Module):
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq) freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation # Different from paper, but it uses a different permutation in order to obtain the same calculation
......
...@@ -774,8 +774,8 @@ class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module): ...@@ -774,8 +774,8 @@ class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
""" """
half_dim = embedding_dim // 2 half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1: if embedding_dim % 2 == 1:
# zero pad # zero pad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment