Commit 18107563 authored by thomwolf
Browse files

updating model loading and adding special tokens ids

parent ebd2cb8d
...@@ -6,14 +6,13 @@ import logging ...@@ -6,14 +6,13 @@ import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetModel.from_pretrained('xlnet-large-cased') model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', attn_type='uni')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
tokens = tokenizer.encode('I am very ') tokens = tokenizer.encode('I am very happy')
for i in range(len(tokens), 20): for i in range(len(tokens), 20):
mask = torch.tensor([[[0.0] * i + [1.0]]]) mask = torch.tensor([[[0.0] * i + [1.0]]])
logits, _ = model(torch.tensor([tokens + [0]]), logits, _ = model(torch.tensor([tokens + [0]]),
perm_mask=mask.expand(-1, i+1, -1), # perm_mask=mask.expand(-1, i+1, -1),
target_mapping=mask, target_mapping=mask,
inp_q=mask.squeeze(1)) inp_q=mask.squeeze(1))
output = torch.multinomial(F.softmax(logits[0, 0, :]), 1) output = torch.multinomial(F.softmax(logits[0, 0, :]), 1)
......
...@@ -730,12 +730,17 @@ class XLNetPreTrainedModel(nn.Module): ...@@ -730,12 +730,17 @@ class XLNetPreTrainedModel(nn.Module):
# Load config # Load config
config = XLNetConfig.from_json_file(resolved_config_file) config = XLNetConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Update config with kwargs if needed # Update config with kwargs if needed
for key, value in kwargs: to_remove = []
for key, value in kwargs.items():
if hasattr(config, key): if hasattr(config, key):
setattr(config, key, value) setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
......
...@@ -36,7 +36,29 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = { ...@@ -36,7 +36,29 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
VOCAB_NAME = 'spiece.model' VOCAB_NAME = 'spiece.model'
SPECIAL_TOKENS_NAME = 'special_tokens.txt' SPECIAL_TOKENS_NAME = 'special_tokens.txt'
SPIECE_UNDERLINE = '▁' SPIECE_UNDERLINE = u'▁'
# Tokens
special_symbols = {
"<unk>" : 0,
"<s>" : 1,
"</s>" : 2,
"<cls>" : 3,
"<sep>" : 4,
"<pad>" : 5,
"<mask>" : 6,
"<eod>" : 7,
"<eop>" : 8,
}
VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
# Segments (not really needed)
SEG_ID_A = 0 SEG_ID_A = 0
SEG_ID_B = 1 SEG_ID_B = 1
SEG_ID_CLS = 2 SEG_ID_CLS = 2
......
xlnet @ cbdedecb
Subproject commit cbdedecbc7951fc000a1547f9feb086c34f0698b
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment