Commit 66a5a6fd authored by Ashwin Geet Dsa, committed by GitHub

Fix to ensure that tensors returned after tokenization are Long (#7039)



* Fix to ensure that tensors returned after tokenization are Long
Co-authored-by: Ashwin Geet Dsa <adsa@grvingt-6.nancy.grid5000.fr>
parent 9ccdb1d5
@@ -149,7 +149,7 @@ class DataCollatorForLanguageModeling:
     ) -> torch.Tensor:
         # In order to accept both lists of lists and lists of Tensors
         if isinstance(examples[0], (list, tuple)):
-            examples = [torch.Tensor(e) for e in examples]
+            examples = [torch.tensor(e, dtype=torch.long) for e in examples]
         length_of_first = examples[0].size(0)
         are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
         if are_tensors_same_length:
...
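
The one-line change above matters because `torch.Tensor(...)` is the legacy constructor and produces a tensor of the default floating-point dtype, whereas `torch.tensor(..., dtype=torch.long)` produces int64 token ids. The following is a minimal sketch of that difference, not code from the commit: the token ids and the embedding sizes are made-up values for illustration, and it assumes PyTorch's default dtype has not been changed from float32.

```python
import torch

# Hypothetical token ids standing in for tokenizer output (not from the commit).
examples = [[101, 2023, 102], [101, 2003, 102]]

# Before the fix: the legacy constructor returns the default dtype (float32 here).
old = [torch.Tensor(e) for e in examples]

# After the fix: an explicit dtype keeps the ids as int64 (Long).
new = [torch.tensor(e, dtype=torch.long) for e in examples]

old_dtype = old[0].dtype
new_dtype = new[0].dtype

# Same-length examples can be stacked into a batch that is still Long.
batch = torch.stack(new, dim=0)

# Embedding lookups expect integer indices, which is where the dtype matters.
# The vocabulary size and embedding width are illustrative assumptions.
embedding = torch.nn.Embedding(num_embeddings=30522, embedding_dim=4)
out = embedding(batch)

print(old_dtype, new_dtype, tuple(out.shape))
```

With the Long batch the lookup returns one vector per token id; the float tensors from the old code path would not be suitable as embedding indices.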