Unverified commit 7e451193, authored by Geewook Kim and committed by GitHub
Browse files

feat: add categorical special tokens (optional), related to #10

parent dd12dae5
......@@ -61,6 +61,18 @@ def train(config):
datasets = {"train": [], "validation": []}
for i, dataset_name_or_path in enumerate(config.dataset_name_or_paths):
task_name = os.path.basename(dataset_name_or_path) # e.g., cord-v2, docvqa, rvlcdip, ...
# add categorical special tokens (optional)
if task_name == "rvlcdip":
model_module.model.decoder.add_special_tokens([
"<advertisement/>", "<budget/>", "<email/>", "<file_folder/>",
"<form/>", "<handwritten/>", "<invoice/>", "<letter/>",
"<memo/>", "<news_article/>", "<presentation/>", "<questionnaire/>",
"<resume/>", "<scientific_publication/>", "<scientific_report/>", "<specification/>"
])
if task_name == "docvqa":
model_module.model.decoder.add_special_tokens(["<yes/>", "<no/>"])
for split in ["train", "validation"]:
datasets[split].append(
DonutDataset(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment