"lightx2v/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "d242358fd31c3d6042f9abed5c7ad5d14a89b69b"
Commit ebd2cb8d authored by thomwolf

update from_pretrained to load XLNetModel as well

parent 483cbc36
import torch
from torch.nn import functional as F
from pytorch_pretrained_bert import XLNetModel, XLNetLMHeadModel, XLNetTokenizer

import logging
logging.basicConfig(level=logging.INFO)

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')

# Both model classes can now be loaded from the same pre-trained checkpoint.
model = XLNetModel.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')

# Sample one token at a time, always predicting the appended placeholder position.
tokens = tokenizer.encode('I am very ')
for i in range(len(tokens), 20):
    mask = torch.tensor([[[0.0] * i + [1.0]]])  # shape [1, 1, i+1]
    logits, _ = model(torch.tensor([tokens + [0]]),
                      perm_mask=mask.expand(-1, i+1, -1),  # no position may attend to the last token
                      target_mapping=mask,                 # predict only the last position
                      inp_q=mask.squeeze(1))               # mark the last position as a prediction target
    output = torch.multinomial(F.softmax(logits[0, 0, :], dim=-1), 1)
    tokens.append(output.item())
print(tokenizer.decode(tokens))
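As a side note on the loop above: the three mask tensors encode a single-token prediction. A minimal restatement of their construction, under the shape conventions documented later in this diff (the helper name is ours, not part of the library):

import torch

def build_last_token_masks(seq_len):
    # Hypothetical helper mirroring the loop above: build the two-stream
    # attention inputs needed to predict the token at the last position.
    # perm_mask[k, i, j] = 1 means position i may NOT attend to position j.
    perm_mask = torch.zeros(1, seq_len, seq_len)
    perm_mask[:, :, -1] = 1.0          # hide the position being predicted from everyone
    # target_mapping[k, i, j] = 1 marks token j as the i-th prediction target.
    target_mapping = torch.zeros(1, 1, seq_len)
    target_mapping[:, 0, -1] = 1.0     # predict only the last position
    # inp_q flags the positions that carry a prediction loss.
    inp_q = target_mapping[:, 0, :].clone()
    return perm_mask, target_mapping, inp_q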
@@ -727,16 +727,24 @@ class XLNetPreTrainedModel(nn.Module):
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        # Load config
        config = XLNetConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))
+        # Update config with kwargs if needed
+        for key, value in kwargs.items():
+            if hasattr(config, key):
+                setattr(config, key, value)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:
            # Directly load from a TensorFlow checkpoint
-            return load_tf_weights_in_xlnet(model, resolved_archive_file)
+            return load_tf_weights_in_xlnet(model, config, resolved_archive_file)
        # Load from a PyTorch state_dict
        missing_keys = []
        unexpected_keys = []
@@ -755,8 +763,8 @@ class XLNetPreTrainedModel(nn.Module):
                if child is not None:
                    load(child, prefix + name + '.')
        start_prefix = ''
-        if not hasattr(model, 'xlnet') and any(s.startswith('xlnet.') for s in state_dict.keys()):
-            start_prefix = 'xlnet.'
+        if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
+            start_prefix = 'transformer.'
        load(model, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
@@ -989,10 +997,10 @@ class XLNetModel(XLNetPreTrainedModel):
        output_h = self.dropout(word_emb_k)
        if inp_q is not None:
            if target_mapping is not None:
-                word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, -1)
+                word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
            else:
                inp_q_ext = inp_q[:, :, None]
-                word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
+                word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q)
        else:
            output_g = None
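For the shapes in the fix above (our reading, not stated in the diff): mask_emb is assumed to be a learned parameter of shape [1, 1, d_model], expanded to one query-stream embedding per prediction target and broadcast over the batch. A quick shape check under that assumption:

import torch

# Assumed shapes only; d_model, num_predict and bsz are arbitrary values here.
d_model, num_predict, bsz = 8, 1, 2
mask_emb = torch.zeros(1, 1, d_model)               # stands in for self.mask_emb
word_emb_q = mask_emb.expand(num_predict, bsz, -1)  # one row per prediction target
assert word_emb_q.shape == (num_predict, bsz, d_model)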
@@ -1062,19 +1070,26 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            This can be used to compute head importance metrics. Default: False
    Inputs:
-        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
-            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
-            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
-        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
-            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
-            a `sentence B` token (see XLNet paper for more details).
-        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
-            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
-            input sequence length in the current batch. It's the mask that we typically use for attention when
-            a batch has varying length sentences.
-        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
-        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
-            It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+        inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+        seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
+        input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
+            0 for real tokens and 1 for padding.
+        mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+            from previous batches. The length of the list equals n_layer.
+            If None, no memory is used.
+        perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
+            If perm_mask[k, i, j] = 0, i attend to j in batch k;
+            if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
+            If None, each position attends to all the others.
+        target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
+            If target_mapping[k, i, j] = 1, the i-th predict in batch k is
+            on the j-th token.
+            Only used during pretraining for partial prediction.
+            Set to None during finetuning.
+        inp_q: [optional] float32 Tensor in shape [bsz, len].
+            1 for tokens with losses and 0 for tokens without losses.
+            Only used during pretraining for two-stream attention.
+            Set to None during finetuning.
    Outputs: Tuple of (encoded_layers, pooled_output)
......
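The mems argument documented above can be illustrated with a hedged sketch of chunked evaluation; the two-element (logits, mems) return and the mems keyword follow the test changes further down, while the variable names and sizes here are ours:

import torch
from pytorch_pretrained_bert import XLNetLMHeadModel

model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
model.eval()
# Two [bsz=1, len=16] chunks of random token ids (32000 is the assumed
# sentencepiece vocabulary size, only used here to draw random ids).
chunks = torch.randint(0, 32000, (2, 1, 16))

mems = None
for chunk in chunks:
    logits, mems = model(chunk, mems=mems)  # memories carry hidden states to the next chunk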
@@ -37,6 +37,11 @@ VOCAB_NAME = 'spiece.model'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
SPIECE_UNDERLINE = '▁'
+SEG_ID_A = 0
+SEG_ID_B = 1
+SEG_ID_CLS = 2
+SEG_ID_SEP = 3
+SEG_ID_PAD = 4
class XLNetTokenizer(object):
"""
@@ -52,6 +57,16 @@ class XLNetTokenizer(object):
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            special_tokens_file = None
+            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
+                               "you may want to check this behavior.")
+                kwargs['do_lower_case'] = False
+            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
+                               "but you may want to check this behavior.")
+                kwargs['do_lower_case'] = True
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
......
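A hedged usage sketch of the new casing guard just above: requesting lower-casing for a cased checkpoint is overridden, with a warning, rather than silently honored.

from pytorch_pretrained_bert import XLNetTokenizer

# 'xlnet-large-cased' uses a cased vocabulary, so do_lower_case is forced back
# to False and the warning message above is logged.
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased', do_lower_case=True)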
@@ -78,23 +78,30 @@ class XLNetModelTest(unittest.TestCase):
            input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-            # inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
-            # seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
-            # input_mask: float32 Tensor in shape [len, bsz], the input mask.
+            input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
+            perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
+            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
+            target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
+            target_mapping[:, 0, -1] = 1.0  # predict last token
+            inp_q = target_mapping[:, 0, :].clone()  # predict last token
+            # inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+            # seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
+            # input_mask: float32 Tensor in shape [bsz, len], the input mask.
            # 0 for real tokens and 1 for padding.
-            # mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+            # mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory
            # from previous batches. The length of the list equals n_layer.
            # If None, no memory is used.
-            # perm_mask: float32 Tensor in shape [len, len, bsz].
-            # If perm_mask[i, j, k] = 0, i attend to j in batch k;
-            # if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
+            # perm_mask: float32 Tensor in shape [bsz, len, len].
+            # If perm_mask[k, i, j] = 0, i attend to j in batch k;
+            # if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
            # If None, each position attends to all the others.
-            # target_mapping: float32 Tensor in shape [num_predict, len, bsz].
-            # If target_mapping[i, j, k] = 1, the i-th predict in batch k is
+            # target_mapping: float32 Tensor in shape [bsz, num_predict, len].
+            # If target_mapping[k, i, j] = 1, the i-th predict in batch k is
            # on the j-th token.
            # Only used during pretraining for partial prediction.
            # Set to None during finetuning.
-            # inp_q: float32 Tensor in shape [len, bsz].
+            # inp_q: float32 Tensor in shape [bsz, len].
            # 1 for tokens with losses and 0 for tokens without losses.
            # Only used during pretraining for two-stream attention.
            # Set to None during finetuning.
@@ -121,30 +128,35 @@ class XLNetModelTest(unittest.TestCase):
            config.update(run_config)
-            return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)
+            return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)

        def set_seed(self):
            random.seed(self.seed)
            torch.manual_seed(self.seed)

-        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
+        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
            model = XLNetLMHeadModel(config)
            model.eval()

            loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels)
-            lm_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
+            all_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
            loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a)
-            lm_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
+            all_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
+            logits, _ = model(input_ids_q,
+                              perm_mask=perm_mask,
+                              target_mapping=target_mapping,
+                              inp_q=inp_q)

            outputs = {
                "loss_1": loss_1,
                "mems_1a": mems_1a,
-                "lm_logits_1": lm_logits_1,
+                "all_logits_1": all_logits_1,
                "mems_1b": mems_1b,
                "loss_2": loss_2,
                "mems_2a": mems_2a,
-                "lm_logits_2": lm_logits_2,
+                "all_logits_2": all_logits_2,
                "mems_2b": mems_2b,
            }
            return outputs
@@ -154,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
                list(result["loss_1"].size()),
                [])
            self.parent.assertListEqual(
-                list(result["lm_logits_1"].size()),
+                list(result["all_logits_1"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1a"]),
@@ -170,7 +182,7 @@ class XLNetModelTest(unittest.TestCase):
                list(result["loss_2"].size()),
                [])
            self.parent.assertListEqual(
-                list(result["lm_logits_2"].size()),
+                list(result["all_logits_2"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2a"]),
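One check the commit does not add, sketched here under our assumptions: the masked call in create_transfo_xl_lm_head predicts a single target position, so if those logits were also stored in the outputs dict they would be expected to have shape [batch_size, num_predict, vocab_size] with num_predict = 1, consistent with the logits[0, 0, :] indexing in the sample script at the top of this commit.

# Hypothetical companion assertion, assuming `logits` were added to `outputs`
# in create_transfo_xl_lm_head (it is computed there but not currently stored).
self.parent.assertListEqual(
    list(result["logits"].size()),
    [self.batch_size, 1, self.vocab_size])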
......