Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c0c7ff57
Commit
c0c7ff57
authored
Jun 01, 2019
by
VictorSanh
Browse files
add transformer xl compatibility for torchhub
parent
48a58646
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
143 additions
and
6 deletions
+143
-6
hubconfs/transformer_xl_hubconf.py
hubconfs/transformer_xl_hubconf.py
+132
-0
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+11
-6
No files found.
hubconfs/transformer_xl_hubconf.py
0 → 100644
View file @
c0c7ff57
from
pytorch_pretrained_bert.tokenization_transfo_xl
import
TransfoXLTokenizer
from
pytorch_pretrained_bert.modeling_transfo_xl
import
(
TransfoXLModel
,
TransfoXLLMHeadModel
)
# A lot of models share the same param doc. Use a decorator
# to save typing
transformer_xl_docstring
=
"""
Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that:
- you don't need to specify positioning embeddings indices
- the tokens in the vocabulary have to be sorted to decreasing frequency.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `transfo-xl-wt103`
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific TransformerXL class
"""
def
_append_from_pretrained_docstring
(
docstr
):
def
docstring_decorator
(
fn
):
fn
.
__doc__
=
fn
.
__doc__
+
docstr
return
fn
return
docstring_decorator
def
transformerXLTokenizer
(
*
args
,
**
kwargs
):
"""
Instantiate a Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
Args:
pretrained_model_name_or_path: Path to pretrained model archive
or one of pre-trained vocab configs below.
* transfo-xl-wt103
Example:
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLTokenizer', 'transfo-xl-wt103')
>>> text = "Who was Jim Henson ?"
>>> tokenized_text = tokenizer.tokenize(tokenized_text)
>>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
"""
tokenizer
=
TransfoXLTokenizer
.
from_pretrained
(
*
args
,
**
kwargs
)
return
tokenizer
@
_append_from_pretrained_docstring
(
transformer_xl_docstring
)
def
transformerXLModel
(
*
args
,
**
kwargs
):
"""
gpt2Model is the basic OpenAI GPT-2 Transformer model based on
identical stacked masked self-attention blocks and pre-trained
on large scale dataset using language modeling signal.
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLTokenizer', 'transfo-xl-wt103')
# Prepare tokenized input
>>> text_1 = "Who was Jim Henson ?"
>>> text_2 = "Jim Henson was a puppeteer"
>>> tokenized_text_1 = tokenizer.tokenize(text_1)
>>> tokenized_text_2 = tokenizer.tokenize(text_2)
>>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
>>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
>>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
>>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
# Load transformerXLModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLModel', 'transfo-xl-wt103')
>>> model.eval()
# Predict hidden states features for each layer
# We can re-use the memory cells in a subsequent call to attend a longer context
>>> with torch.no_grad():
hidden_states_1, mems_1 = model(tokens_tensor_1)
hidden_states_2, past = model(tokens_tensor_2, past=past)
"""
model
=
TransfoXLModel
.
from_pretrained
(
*
args
,
**
kwargs
)
return
model
@
_append_from_pretrained_docstring
(
transformer_xl_docstring
)
def
transformerXLLMHeadModel
(
*
args
,
**
kwargs
):
"""
gpt2LMHeadModel is the OpenAI GPT-2 Transformer model with the
tied (pre-trained) language modeling head on top.
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLTokenizer', 'transfo-xl-wt103')
# Prepare tokenized input
>>> text_1 = "Who was Jim Henson ?"
>>> text_2 = "Jim Henson was a puppeteer"
>>> tokenized_text_1 = tokenizer.tokenize(text_1)
>>> tokenized_text_2 = tokenizer.tokenize(text_2)
>>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
>>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
>>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
>>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
# Load transformerXLLMHeadModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLLMHeadModel', 'transfo-xl-wt103')
>>> model.eval()
# Predict hidden states features for each layer
# We can re-use the memory cells in a subsequent call to attend a longer context
>>> with torch.no_grad():
predictions_1, mems_1 = model(tokens_tensor_1)
predictions_2, past = model(tokens_tensor_2, past=past)
# Get the predicted last token
>>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
>>> predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
>>> assert predicted_token == 'who'
"""
model
=
TransfoXLLMHeadModel
.
from_pretrained
(
*
args
,
**
kwargs
)
return
model
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
c0c7ff57
...
...
@@ -888,8 +888,7 @@ class TransfoXLPreTrainedModel(nn.Module):
pass
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
state_dict
=
None
,
cache_dir
=
None
,
from_tf
=
False
,
*
inputs
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
"""
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
...
...
@@ -897,19 +896,25 @@ class TransfoXLPreTrainedModel(nn.Module):
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `transfo-xl`
. `transfo-xl
-wt103
`
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
- a path or url to a pretrained model archive containing:
. `
bert
_config.json` a configuration file for the model
. `
transfo_xl
_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
*inputs, **kwargs: additional input for the specific TransformerXL class
"""
state_dict
=
kwargs
.
get
(
'state_dict'
,
None
)
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
get
(
'cache_dir'
,
None
)
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
get
(
'from_tf'
,
False
)
kwargs
.
pop
(
'from_tf'
,
None
)
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment