Unverified Commit ec026747 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Fix FlauBERT GPU test (#6142)

* Fix GPU test

* Remove legacy constructor
parent 91cb9546
...@@ -163,11 +163,13 @@ class FlaubertModel(XLMModel): ...@@ -163,11 +163,13 @@ class FlaubertModel(XLMModel):
else: else:
bs, slen = inputs_embeds.size()[:-1] bs, slen = inputs_embeds.size()[:-1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if lengths is None: if lengths is None:
if input_ids is not None: if input_ids is not None:
lengths = (input_ids != self.pad_index).sum(dim=1).long() lengths = (input_ids != self.pad_index).sum(dim=1).long()
else: else:
lengths = torch.LongTensor([slen] * bs) lengths = torch.tensor([slen] * bs, device=device)
# mask = input_ids != self.pad_index # mask = input_ids != self.pad_index
# check inputs # check inputs
...@@ -184,8 +186,6 @@ class FlaubertModel(XLMModel): ...@@ -184,8 +186,6 @@ class FlaubertModel(XLMModel):
# if self.is_decoder and src_enc is not None: # if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
device = input_ids.device if input_ids is not None else inputs_embeds.device
# position_ids # position_ids
if position_ids is None: if position_ids is None:
position_ids = torch.arange(slen, dtype=torch.long, device=device) position_ids = torch.arange(slen, dtype=torch.long, device=device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment