Unverified commit 304aacac, authored by Younes Belkada and committed by GitHub
Browse files

🚨🚨🚨 [`Pix2Struct`] Attempts to fix training issues 🚨🚨🚨 (#23004)

* multiple fixes

- set `add_special_tokens` to `True` by default
- remove label smoothing and labels masking

* fix test
parent ba0dc545
......@@ -1554,10 +1554,9 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean", label_smoothing=0.1)
masked_labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), masked_labels.contiguous().view(-1))
loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1))
if not return_dict:
return tuple(
......
......@@ -49,7 +49,7 @@ class Pix2StructProcessor(ProcessorMixin):
self,
images=None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = False,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
......
......@@ -108,7 +108,7 @@ class Pix2StructProcessorTest(unittest.TestCase):
encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str, return_token_type_ids=False, add_special_tokens=False)
encoded_tok = tokenizer(input_str, return_token_type_ids=False, add_special_tokens=True)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment