# constants.py
import json


# Read the special-token mapping once, at import time. A missing or malformed
# JSON file therefore makes importing this module fail.
with open(
    "/public/home/zhaoying1/work/TencentPretrain-main/models/llama_special_tokens_map.json",
    mode="r",
    encoding="utf-8",
) as f:
    special_tokens_map = json.load(f)

# Mandatory tokens, looked up in this order; a key that is absent from the
# mapping raises KeyError.
UNK_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, PAD_TOKEN = (
    special_tokens_map[token_key]
    for token_key in ("unk_token", "cls_token", "sep_token", "mask_token", "pad_token")
)

# Optional token: SENTINEL_TOKEN is only defined when the mapping provides a
# "sentinel_token" entry; otherwise the name is left unbound.
# e.g. <extra_id_0>, <extra_id_1>, ... , should have consecutive IDs.
if "sentinel_token" in special_tokens_map:
    SENTINEL_TOKEN = special_tokens_map["sentinel_token"]