"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "c0336b751da4af0352b83e9bd4f397ad3e2b0f6a"
Unverified Commit 8dfc8c72 authored by Ethan Perez's avatar Ethan Perez Committed by GitHub
Browse files

Don't pass in token_type_ids to BART for GLUE (#8929)

Without this fix, training a `BARTForSequenceClassification` model with `run_pl_glue.py` gives `TypeError: forward() got an unexpected keyword argument 'token_type_ids'`, because BART does not have token_type_ids. I've solved this issue in the same way as it's solved for the "distilbert" model, and I can train BART models on SNLI without errors now.
parent df311a5c
...@@ -38,7 +38,7 @@ class GLUETransformer(BaseTransformer): ...@@ -38,7 +38,7 @@ class GLUETransformer(BaseTransformer):
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
if self.config.model_type != "distilbert": if self.config.model_type not in ["distilbert", "bart"]:
inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None
outputs = self(**inputs) outputs = self(**inputs)
...@@ -101,7 +101,7 @@ class GLUETransformer(BaseTransformer): ...@@ -101,7 +101,7 @@ class GLUETransformer(BaseTransformer):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
if self.config.model_type != "distilbert": if self.config.model_type not in ["distilbert", "bart"]:
inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None
outputs = self(**inputs) outputs = self(**inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment