"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "aa925a52fad9d6b98dac4c1b27f881bef7e88dad"
Unverified Commit 0d9328f2 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Patch GPU failures (#6281)

* Pin to 1.5.0

* Patch XLM GPU test
parent 80a0676a
...@@ -36,7 +36,7 @@ jobs: ...@@ -36,7 +36,7 @@ jobs:
run: | run: |
source .env/bin/activate source .env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install torch --no-cache-dir pip install torch!=1.6.0 --no-cache-dir
pip install .[sklearn,testing] pip install .[sklearn,testing]
- name: Are GPUs recognized by our DL frameworks - name: Are GPUs recognized by our DL frameworks
......
...@@ -496,11 +496,13 @@ class XLMModel(XLMPreTrainedModel): ...@@ -496,11 +496,13 @@ class XLMModel(XLMPreTrainedModel):
else: else:
bs, slen = inputs_embeds.size()[:-1] bs, slen = inputs_embeds.size()[:-1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if lengths is None: if lengths is None:
if input_ids is not None: if input_ids is not None:
lengths = (input_ids != self.pad_index).sum(dim=1).long() lengths = (input_ids != self.pad_index).sum(dim=1).long()
else: else:
lengths = torch.LongTensor([slen] * bs) lengths = torch.tensor([slen] * bs, device=device)
# mask = input_ids != self.pad_index # mask = input_ids != self.pad_index
# check inputs # check inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment