Unverified Commit e2964b8a authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[fsmt] no need to pass device (#7292)

parent e4b94d8e
......@@ -1150,13 +1150,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions, embedding_dim, padding_idx):
self.make_weight(num_positions, embedding_dim, padding_idx)
def make_weight(self, num_positions, embedding_dim, padding_idx, device=None):
def make_weight(self, num_positions, embedding_dim, padding_idx):
weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
if device is not None:
weight = weight.to(device)
if not hasattr(self, "weight"):
# in ___init__
super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight)
else:
# in forward
weight = weight.to(self.weight.device)
self.weight = nn.Parameter(weight)
self.weight.detach_()
self.weight.requires_grad = False
......@@ -1204,6 +1205,6 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
max_pos = self.padding_idx + 1 + seq_len
if max_pos > self.weight.size(0):
# expand embeddings if needed
self.make_weight(max_pos, self.embedding_dim, self.padding_idx, device=input.device)
self.make_weight(max_pos, self.embedding_dim, self.padding_idx)
positions = self.make_positions(input, self.padding_idx)
return super().forward(positions)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment