Commit ebd2cb8d authored by thomwolf

update from_pretrained to load XLNetModel as well

parent 483cbc36
import torch
from torch.nn import functional as F
from pytorch_pretrained_bert import XLNetModel, XLNetLMHeadModel, XLNetTokenizer

import logging
logging.basicConfig(level=logging.INFO)

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetModel.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')

tokens = tokenizer.encode('I am very ')
for i in range(len(tokens), 20):
    # Mask of shape [1, 1, i+1]: 0.0 over the i known tokens, 1.0 on the position to predict.
    mask = torch.tensor([[[0.0] * i + [1.0]]])
    logits, _ = model(torch.tensor([tokens + [0]]),          # append a dummy token at the target position
                      perm_mask=mask.expand(-1, i+1, -1),    # no position may attend to the target
                      target_mapping=mask,                   # predict only the target position
                      inp_q=mask.squeeze(1))                 # flag the target position for the query stream
    output = torch.multinomial(F.softmax(logits[0, 0, :]), 1)
    tokens.append(output.item())
print(tokenizer.decode(tokens))
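Note (not part of the commit): the loop above samples from the full softmax over the vocabulary. A minimal sketch of restricting sampling to the top-k logits instead, using only standard PyTorch operations; `logits` stands in for the [vocab_size] vector `logits[0, 0, :]` from the loop, and the vocabulary size of 32000 is illustrative:

import torch
from torch.nn import functional as F

def sample_top_k(logits, k=40):
    # Keep the k largest logits, renormalize them, and sample one index from that subset.
    top_values, top_indices = torch.topk(logits, k)
    probs = F.softmax(top_values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return top_indices[choice].item()

next_token = sample_top_k(torch.randn(32000), k=40)  # illustrative logit vector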
@@ -727,16 +727,24 @@ class XLNetPreTrainedModel(nn.Module):
                 archive_file, resolved_archive_file))
             logger.info("loading configuration file {} from cache at {}".format(
                 config_file, resolved_config_file))
         # Load config
         config = XLNetConfig.from_json_file(resolved_config_file)
         logger.info("Model config {}".format(config))
+        # Update config with kwargs if needed
+        for key, value in kwargs:
+            if hasattr(config, key):
+                setattr(config, key, value)
         # Instantiate model.
         model = cls(config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
             state_dict = torch.load(resolved_archive_file, map_location='cpu')
         if from_tf:
             # Directly load from a TensorFlow checkpoint
-            return load_tf_weights_in_xlnet(model, resolved_archive_file)
+            return load_tf_weights_in_xlnet(model, config, resolved_archive_file)
         # Load from a PyTorch state_dict
         missing_keys = []
         unexpected_keys = []
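The added loop above copies recognized keyword arguments onto the freshly loaded config before the model is instantiated. A self-contained sketch of that pattern (not the library's code), using a hypothetical Config class and iterating kwargs.items() to obtain key/value pairs:

class Config:
    # Hypothetical stand-in for a pretrained-model configuration object.
    def __init__(self):
        self.d_model = 1024
        self.dropout = 0.1

def apply_overrides(config, **kwargs):
    # Copy only the keyword arguments that name existing config fields; ignore the rest.
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
    return config

config = apply_overrides(Config(), dropout=0.2, unknown_flag=True)
print(config.dropout)                    # 0.2
print(hasattr(config, 'unknown_flag'))   # False: unrecognized kwargs are ignored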
@@ -755,8 +763,8 @@ class XLNetPreTrainedModel(nn.Module):
                 if child is not None:
                     load(child, prefix + name + '.')
         start_prefix = ''
-        if not hasattr(model, 'xlnet') and any(s.startswith('xlnet.') for s in state_dict.keys()):
-            start_prefix = 'xlnet.'
+        if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
+            start_prefix = 'transformer.'
         load(model, prefix=start_prefix)
         if len(missing_keys) > 0:
             logger.info("Weights of {} not initialized from pretrained model: {}".format(
@@ -989,10 +997,10 @@ class XLNetModel(XLNetPreTrainedModel):
         output_h = self.dropout(word_emb_k)
         if inp_q is not None:
             if target_mapping is not None:
-                word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, -1)
+                word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
             else:
                 inp_q_ext = inp_q[:, :, None]
-                word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
+                word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
             output_g = self.dropout(word_emb_q)
         else:
             output_g = None
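The two lines above read the learned mask embedding from the module (`self.mask_emb`) rather than a bare `mask_emb` name. As a shape-level illustration of what the query stream does with it (use the mask embedding at positions flagged by `inp_q`, the content embedding elsewhere), here is a standalone toy example with made-up dimensions, not the model's internal tensor layout:

import torch

bsz, seq_len, d_model = 2, 5, 8
word_emb_k = torch.randn(bsz, seq_len, d_model)   # content-stream embeddings
mask_emb = torch.randn(1, 1, d_model)             # single shared mask embedding
inp_q = torch.zeros(bsz, seq_len)
inp_q[:, -1] = 1.0                                # predict the last position only

inp_q_ext = inp_q[:, :, None]                     # broadcast to [bsz, seq_len, 1]
word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
print(word_emb_q.shape)                           # torch.Size([2, 5, 8])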
@@ -1062,19 +1070,26 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
             This can be used to compute head importance metrics. Default: False
     Inputs:
-        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
-            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
-            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
-        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
-            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
-            a `sentence B` token (see XLNet paper for more details).
-        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
-            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
-            input sequence length in the current batch. It's the mask that we typically use for attention when
-            a batch has varying length sentences.
-        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
-        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
-            It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+        inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+        seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
+        input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
+            0 for real tokens and 1 for padding.
+        mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+            from previous batches. The length of the list equals n_layer.
+            If None, no memory is used.
+        perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
+            If perm_mask[k, i, j] = 0, i attend to j in batch k;
+            if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
+            If None, each position attends to all the others.
+        target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
+            If target_mapping[k, i, j] = 1, the i-th predict in batch k is
+            on the j-th token.
+            Only used during pretraining for partial prediction.
+            Set to None during finetuning.
+        inp_q: [optional] float32 Tensor in shape [bsz, len].
+            1 for tokens with losses and 0 for tokens without losses.
+            Only used during pretraining for two-stream attention.
+            Set to None during finetuning.
     Outputs: Tuple of (encoded_layers, pooled_output)
...
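To make the docstring above concrete, here is a small hedged sketch (not part of the commit) that builds `perm_mask`, `target_mapping`, and `inp_q` for predicting the final position of each sequence in a batch, following the index semantics described above:

import torch

bsz, seq_len = 2, 6

# perm_mask[k, i, j] = 1 means position i may not attend to position j in batch k.
perm_mask = torch.zeros(bsz, seq_len, seq_len)
perm_mask[:, :, -1] = 1.0          # nobody attends to the last (to-be-predicted) token

# target_mapping[k, i, j] = 1 means the i-th prediction in batch k is on token j.
target_mapping = torch.zeros(bsz, 1, seq_len)
target_mapping[:, 0, -1] = 1.0     # one prediction per example, on the last token

# inp_q marks the positions that carry a loss (prediction targets).
inp_q = target_mapping[:, 0, :].clone()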
@@ -37,6 +37,11 @@ VOCAB_NAME = 'spiece.model'
 SPECIAL_TOKENS_NAME = 'special_tokens.txt'
 SPIECE_UNDERLINE = '▁'
+SEG_ID_A = 0
+SEG_ID_B = 1
+SEG_ID_CLS = 2
+SEG_ID_SEP = 3
+SEG_ID_PAD = 4
 class XLNetTokenizer(object):
     """
@@ -52,6 +57,16 @@ class XLNetTokenizer(object):
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
             special_tokens_file = None
+            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
+                               "you may want to check this behavior.")
+                kwargs['do_lower_case'] = False
+            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
+                               "but you may want to check this behavior.")
+                kwargs['do_lower_case'] = True
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
             special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
...
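The new SEG_ID_* constants reflect XLNet's input layout, in which the classification token sits at the end of the sequence. The helper below is hypothetical (not part of this commit) and sketches how the constants could be used to assemble segment IDs for a sentence pair; '<sep>' and '<cls>' are assumed to be the tokenizer's special tokens:

SEG_ID_A, SEG_ID_B, SEG_ID_CLS, SEG_ID_SEP, SEG_ID_PAD = 0, 1, 2, 3, 4

def build_pair_inputs(tokens_a, tokens_b):
    # XLNet layout: A tokens, <sep>, B tokens, <sep>, <cls> (CLS comes last).
    tokens = tokens_a + ['<sep>'] + tokens_b + ['<sep>', '<cls>']
    segment_ids = ([SEG_ID_A] * (len(tokens_a) + 1)
                   + [SEG_ID_B] * (len(tokens_b) + 1)
                   + [SEG_ID_CLS])
    return tokens, segment_ids

tokens, seg = build_pair_inputs(['▁Hello', '▁world'], ['▁Good', 'bye'])
print(list(zip(tokens, seg)))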
@@ -78,23 +78,30 @@ class XLNetModelTest(unittest.TestCase):
             input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
             segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

-            # inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
-            # seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
-            # input_mask: float32 Tensor in shape [len, bsz], the input mask.
+            input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
+            perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
+            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
+            target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
+            target_mapping[:, 0, -1] = 1.0  # predict last token
+            inp_q = target_mapping[:, 0, :].clone()  # predict last token
+
+            # inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+            # seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
+            # input_mask: float32 Tensor in shape [bsz, len], the input mask.
             # 0 for real tokens and 1 for padding.
-            # mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+            # mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory
             # from previous batches. The length of the list equals n_layer.
             # If None, no memory is used.
-            # perm_mask: float32 Tensor in shape [len, len, bsz].
-            # If perm_mask[i, j, k] = 0, i attend to j in batch k;
-            # if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
+            # perm_mask: float32 Tensor in shape [bsz, len, len].
+            # If perm_mask[k, i, j] = 0, i attend to j in batch k;
+            # if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
             # If None, each position attends to all the others.
-            # target_mapping: float32 Tensor in shape [num_predict, len, bsz].
-            # If target_mapping[i, j, k] = 1, the i-th predict in batch k is
+            # target_mapping: float32 Tensor in shape [bsz, num_predict, len].
+            # If target_mapping[k, i, j] = 1, the i-th predict in batch k is
             # on the j-th token.
             # Only used during pretraining for partial prediction.
             # Set to None during finetuning.
-            # inp_q: float32 Tensor in shape [len, bsz].
+            # inp_q: float32 Tensor in shape [bsz, len].
             # 1 for tokens with losses and 0 for tokens without losses.
             # Only used during pretraining for two-stream attention.
             # Set to None during finetuning.
@@ -121,30 +128,35 @@ class XLNetModelTest(unittest.TestCase):
             config.update(run_config)

-            return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)
+            return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)

         def set_seed(self):
             random.seed(self.seed)
             torch.manual_seed(self.seed)

-        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
+        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
             model = XLNetLMHeadModel(config)
             model.eval()

             loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels)
-            lm_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
+            all_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)

             loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a)
-            lm_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
+            all_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)

+            logits, _ = model(input_ids_q,
+                              perm_mask=perm_mask,
+                              target_mapping=target_mapping,
+                              inp_q=inp_q)

             outputs = {
                 "loss_1": loss_1,
                 "mems_1a": mems_1a,
-                "lm_logits_1": lm_logits_1,
+                "all_logits_1": all_logits_1,
                 "mems_1b": mems_1b,
                 "loss_2": loss_2,
                 "mems_2a": mems_2a,
-                "lm_logits_2": lm_logits_2,
+                "all_logits_2": all_logits_2,
                 "mems_2b": mems_2b,
             }
             return outputs
@@ -154,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
                 list(result["loss_1"].size()),
                 [])
             self.parent.assertListEqual(
-                list(result["lm_logits_1"].size()),
+                list(result["all_logits_1"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_1a"]),
@@ -170,7 +182,7 @@ class XLNetModelTest(unittest.TestCase):
                 list(result["loss_2"].size()),
                 [])
             self.parent.assertListEqual(
-                list(result["lm_logits_2"].size()),
+                list(result["all_logits_2"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_2a"]),
...