"examples/vscode:/vscode.git/clone" did not exist on "28ba345eccb8a7af3e044f3dd82c1d661a065d80"
Unverified Commit a722c301 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[SinusoidalPositionalEmbedding] incorrect dtype when make_weights in forward (#13665)

parent 1417978c
......@@ -1272,8 +1272,8 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
# in ___init__
super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight)
else:
# in forward
weight = weight.to(self.weight.device)
# in forward put the weights on the correct dtype and device of the param
weight = weight.to(dtype=self.weight.dtype, device=self.weight.device)
self.weight = nn.Parameter(weight)
self.weight.detach_()
self.weight.requires_grad = False
......
......@@ -126,8 +126,8 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
if hasattr(self, "weights"):
# in forward, put the weights on correct device
emb_weights = emb_weights.to(self.weights.device)
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
......
......@@ -149,8 +149,8 @@ class Speech2TextSinusoidalPositionalEmbedding(nn.Module):
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
if hasattr(self, "weights"):
# in forward, put the weights on correct device
emb_weights = emb_weights.to(self.weights.device)
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
......
......@@ -90,8 +90,8 @@ class Speech2Text2SinusoidalPositionalEmbedding(nn.Module):
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
if hasattr(self, "weights"):
# in forward, put the weights on correct device
emb_weights = emb_weights.to(self.weights.device)
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment