Unverified Commit c6646613 authored by Yih-Dar, committed by GitHub

Fix `get_embedding` dtype at init. time (#19473)



* cast positions dtype in XGLMModel

* Get the correct dtype at init time

* Get the correct dtype at init time
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent e38cf93e
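
For context, every hunk below makes the same one-line change: the sinusoidal table that `get_embedding` builds in float32 is cast to `torch.get_default_dtype()` before being returned, so the positional weights match the dtype the model is being initialized under. A minimal, self-contained sketch of the shared pattern (class name simplified; not the exact transformers source):

```python
# Minimal sketch (assumed/simplified names; not the exact transformers code) of the
# pattern shared by the five patched classes.
import math
from typing import Optional

import torch
from torch import nn


class SinusoidalPositionalEmbedding(nn.Module):
    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """Build a sinusoidal position table; `emb` is created in float32."""
        half_dim = embedding_dim // 2
        scale = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -scale)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero-pad the last column when embedding_dim is odd
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        # The fix: cast to the dtype active at init time, so a model initialized
        # under e.g. torch.set_default_dtype(torch.float16) gets float16 positions.
        return emb.to(torch.get_default_dtype())
```
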
@@ -152,7 +152,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
         if padding_idx is not None:
             emb[padding_idx, :] = 0
-        return emb
+        return emb.to(torch.get_default_dtype())
 
     @torch.no_grad()
     def forward(
@@ -165,7 +165,7 @@ class Speech2TextSinusoidalPositionalEmbedding(nn.Module):
             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
         if padding_idx is not None:
             emb[padding_idx, :] = 0
-        return emb
+        return emb.to(torch.get_default_dtype())
 
     @torch.no_grad()
     def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
@@ -111,7 +111,7 @@ class Speech2Text2SinusoidalPositionalEmbedding(nn.Module):
             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
         if padding_idx is not None:
             emb[padding_idx, :] = 0
-        return emb
+        return emb.to(torch.get_default_dtype())
 
     @torch.no_grad()
     def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
@@ -126,7 +126,7 @@ class TrOCRSinusoidalPositionalEmbedding(nn.Module):
             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
         if padding_idx is not None:
             emb[padding_idx, :] = 0
-        return emb
+        return emb.to(torch.get_default_dtype())
 
     @torch.no_grad()
     def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
@@ -194,7 +194,7 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
         if padding_idx is not None:
             emb[padding_idx, :] = 0
-        return emb
+        return emb.to(torch.get_default_dtype())
 
     @torch.no_grad()
     def forward(
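
A quick usage sketch (assuming the simplified class above) of why the cast matters: when a model is built while a half-precision default dtype is active, as transformers arranges temporarily when `torch_dtype=torch.float16` is passed, the table now comes out in float16 instead of float32:

```python
# Assumes the SinusoidalPositionalEmbedding sketch above.
torch.set_default_dtype(torch.float16)  # what transformers sets around init for torch_dtype=float16
table = SinusoidalPositionalEmbedding.get_embedding(num_embeddings=512, embedding_dim=64, padding_idx=1)
print(table.dtype)  # torch.float16 with this fix; previously torch.float32
torch.set_default_dtype(torch.float32)  # restore the usual default
```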