Unverified Commit aff87da1 authored by Yih-Dar, committed by GitHub

Fix `ErnieMEmbeddings` device issue (#21726)



* remove .parameters()).device

* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 2f2b19ff
@@ -77,7 +77,7 @@ class ErnieMEmbeddings(nn.Module):
             inputs_embeds = self.word_embeddings(input_ids)
         if position_ids is None:
             input_shape = inputs_embeds.size()[:-1]
-            ones = torch.ones(input_shape, dtype=torch.int64)
+            ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device)
             seq_length = torch.cumsum(ones, dim=1)
             position_ids = seq_length - ones
@@ -85,7 +85,6 @@ class ErnieMEmbeddings(nn.Module):
             position_ids = position_ids + past_key_values_length
         # to mimic paddlenlp implementation
         position_ids += 2
-        position_ids = position_ids.to(next(self.position_embeddings.parameters()).device)
         position_embeddings = self.position_embeddings(position_ids)
         embeddings = inputs_embeds + position_embeddings
         embeddings = self.layer_norm(embeddings)
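For context, the removed line created position_ids on the CPU default device and only moved it to the embedding weight's device afterwards; the fix creates the helper tensor on the correct device up front. Below is a minimal sketch of the pattern (the ToyEmbeddings module is hypothetical, not the actual ErnieMEmbeddings code; the usage check assumes a CUDA device is available):

import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    """Hypothetical module illustrating the device-placement pattern from this fix."""

    def __init__(self, vocab_size=100, hidden_size=8, max_positions=512):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)

    def forward(self, input_ids):
        inputs_embeds = self.word_embeddings(input_ids)
        input_shape = inputs_embeds.size()[:-1]
        # Create helper tensors directly on the device of inputs_embeds.
        # Without device=..., torch.ones defaults to CPU, and a later
        # .to(next(...parameters()).device) call would be needed to avoid
        # a device-mismatch error when the model lives on a GPU.
        ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device)
        position_ids = torch.cumsum(ones, dim=1) - ones
        return inputs_embeds + self.position_embeddings(position_ids)


if torch.cuda.is_available():
    model = ToyEmbeddings().cuda()
    input_ids = torch.arange(4).unsqueeze(0).cuda()
    print(model(input_ids).shape)  # torch.Size([1, 4, 8]), no device-mismatch error

Creating the tensor on the right device from the start also avoids an extra host-to-device copy compared with building it on the CPU and moving it afterwards.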