Unverified Commit d1ab1fab authored by Yuval Pinter's avatar Yuval Pinter Committed by GitHub
Browse files

pass langs parameter to certain XLM models (#2734)

* pass langs parameter to certain XLM models

Adding an argument that specifies the language the SQuAD dataset is in so language-sensitive XLMs (e.g. `xlm-mlm-tlm-xnli15-1024`) don't default to language `0`.
Allows resolution of issue #1799 .

* fixing from `make style`

* fixing style (again)
parent 9e5b549b
...@@ -219,6 +219,11 @@ def train(args, train_dataset, model, tokenizer): ...@@ -219,6 +219,11 @@ def train(args, train_dataset, model, tokenizer):
inputs.update({"cls_index": batch[5], "p_mask": batch[6]}) inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
if args.version_2_with_negative: if args.version_2_with_negative:
inputs.update({"is_impossible": batch[7]}) inputs.update({"is_impossible": batch[7]})
if hasattr(model, "config") and hasattr(model.config, "lang2id"):
inputs.update(
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
)
outputs = model(**inputs) outputs = model(**inputs)
# model outputs are always tuple in transformers (see doc) # model outputs are always tuple in transformers (see doc)
loss = outputs[0] loss = outputs[0]
...@@ -330,6 +335,11 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -330,6 +335,11 @@ def evaluate(args, model, tokenizer, prefix=""):
# XLNet and XLM use more arguments for their predictions # XLNet and XLM use more arguments for their predictions
if args.model_type in ["xlnet", "xlm"]: if args.model_type in ["xlnet", "xlm"]:
inputs.update({"cls_index": batch[4], "p_mask": batch[5]}) inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
# for lang_id-sensitive xlm models
if hasattr(model, "config") and hasattr(model.config, "lang2id"):
inputs.update(
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
)
outputs = model(**inputs) outputs = model(**inputs)
...@@ -635,6 +645,12 @@ def main(): ...@@ -635,6 +645,12 @@ def main():
help="If true, all of the warnings related to data processing will be printed. " help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.", "A number of warnings are expected for a normal SQuAD evaluation.",
) )
parser.add_argument(
"--lang_id",
default=0,
type=int,
help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
)
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment