Unverified Commit e2964b8a authored by Stas Bekman, committed by GitHub
Browse files

[fsmt] no need to pass device (#7292)

parent e4b94d8e
...@@ -1150,13 +1150,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding): ...@@ -1150,13 +1150,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions, embedding_dim, padding_idx): def __init__(self, num_positions, embedding_dim, padding_idx):
self.make_weight(num_positions, embedding_dim, padding_idx) self.make_weight(num_positions, embedding_dim, padding_idx)
def make_weight(self, num_positions, embedding_dim, padding_idx, device=None): def make_weight(self, num_positions, embedding_dim, padding_idx):
weight = self.get_embedding(num_positions, embedding_dim, padding_idx) weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
if device is not None:
weight = weight.to(device)
if not hasattr(self, "weight"): if not hasattr(self, "weight"):
# in __init__
super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight) super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight)
else: else:
# in forward
weight = weight.to(self.weight.device)
self.weight = nn.Parameter(weight) self.weight = nn.Parameter(weight)
self.weight.detach_() self.weight.detach_()
self.weight.requires_grad = False self.weight.requires_grad = False
...@@ -1204,6 +1205,6 @@ class SinusoidalPositionalEmbedding(nn.Embedding): ...@@ -1204,6 +1205,6 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
max_pos = self.padding_idx + 1 + seq_len max_pos = self.padding_idx + 1 + seq_len
if max_pos > self.weight.size(0): if max_pos > self.weight.size(0):
# expand embeddings if needed # expand embeddings if needed
self.make_weight(max_pos, self.embedding_dim, self.padding_idx, device=input.device) self.make_weight(max_pos, self.embedding_dim, self.padding_idx)
positions = self.make_positions(input, self.padding_idx) positions = self.make_positions(input, self.padding_idx)
return super().forward(positions) return super().forward(positions)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment