chenpangpang / transformers · Commits

Commit c0c7ff57, authored Jun 01, 2019 by VictorSanh
parent 48a58646

    add transformer xl compatibility for torchhub

Showing 2 changed files, with 143 additions and 6 deletions:

    hubconfs/transformer_xl_hubconf.py               +132  -0
    pytorch_pretrained_bert/modeling_transfo_xl.py    +11  -6
hubconfs/transformer_xl_hubconf.py (new file, mode 100644)
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer
from pytorch_pretrained_bert.modeling_transfo_xl import (
    TransfoXLModel,
    TransfoXLLMHeadModel
)

# A lot of models share the same param doc. Use a decorator
# to save typing
transformer_xl_docstring = """
    Transformer-XL uses relative positioning (with sinusoidal patterns) and adaptive softmax inputs, which means that:
    - you don't need to specify positioning embeddings indices
    - the tokens in the vocabulary have to be sorted by decreasing frequency.

    Params:
        pretrained_model_name_or_path: either:
            - a str with the name of a pre-trained model to load selected in the list of:
                . `transfo-xl-wt103`
            - a path or url to a pretrained model archive containing:
                . `transfo_xl_config.json` a configuration file for the model
                . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
            - a path or url to a pretrained model archive containing:
                . `transfo_xl_config.json` a configuration file for the model
                . `model.chkpt` a TensorFlow checkpoint
        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
        *inputs, **kwargs: additional inputs for the specific TransformerXL class
"""

def _append_from_pretrained_docstring(docstr):
    def docstring_decorator(fn):
        fn.__doc__ = fn.__doc__ + docstr
        return fn
    return docstring_decorator

def transformerXLTokenizer(*args, **kwargs):
    """
    Instantiate a Transformer-XL tokenizer adapted from the Vocab class in https://github.com/kimiyoung/transformer-xl

    Args:
        pretrained_model_name_or_path: Path to pretrained model archive
                                       or one of the pre-trained vocab configs below.
                                           * transfo-xl-wt103

    Example:
        >>> import torch
        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLTokenizer', 'transfo-xl-wt103')

        >>> text = "Who was Jim Henson ?"
        >>> tokenized_text = tokenizer.tokenize(text)
        >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    """
    tokenizer = TransfoXLTokenizer.from_pretrained(*args, **kwargs)
    return tokenizer

@_append_from_pretrained_docstring(transformer_xl_docstring)
def transformerXLModel(*args, **kwargs):
    """
    transformerXLModel is the basic Transformer-XL model based on stacked
    self-attention blocks with relative positional embeddings and reusable
    memory cells, pre-trained on a large scale corpus (WikiText-103) using
    a language modeling signal.

    Example:
        # Load the tokenizer
        >>> import torch
        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLTokenizer', 'transfo-xl-wt103')

        # Prepare tokenized input
        >>> text_1 = "Who was Jim Henson ?"
        >>> text_2 = "Jim Henson was a puppeteer"
        >>> tokenized_text_1 = tokenizer.tokenize(text_1)
        >>> tokenized_text_2 = tokenizer.tokenize(text_2)
        >>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
        >>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
        >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
        >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])

        # Load transformerXLModel
        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLModel', 'transfo-xl-wt103')
        >>> model.eval()

        # Predict hidden states features for each layer
        # We can re-use the memory cells in a subsequent call to attend a longer context
        >>> with torch.no_grad():
                hidden_states_1, mems_1 = model(tokens_tensor_1)
                hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
    """
    model = TransfoXLModel.from_pretrained(*args, **kwargs)
    return model

@_append_from_pretrained_docstring(transformer_xl_docstring)
def transformerXLLMHeadModel(*args, **kwargs):
    """
    transformerXLLMHeadModel is the Transformer-XL model with the
    tied (pre-trained) adaptive softmax language modeling head on top.

    Example:
        # Load the tokenizer
        >>> import torch
        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLTokenizer', 'transfo-xl-wt103')

        # Prepare tokenized input
        >>> text_1 = "Who was Jim Henson ?"
        >>> text_2 = "Jim Henson was a puppeteer"
        >>> tokenized_text_1 = tokenizer.tokenize(text_1)
        >>> tokenized_text_2 = tokenizer.tokenize(text_2)
        >>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
        >>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
        >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
        >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])

        # Load transformerXLLMHeadModel
        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLLMHeadModel', 'transfo-xl-wt103')
        >>> model.eval()

        # Predict all tokens
        # We can re-use the memory cells in a subsequent call to attend a longer context
        >>> with torch.no_grad():
                predictions_1, mems_1 = model(tokens_tensor_1)
                predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1)

        # Get the predicted last token
        >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
        >>> predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        >>> assert predicted_token == 'who'
    """
    model = TransfoXLLMHeadModel.from_pretrained(*args, **kwargs)
    return model
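
The `_append_from_pretrained_docstring` decorator above simply concatenates the shared parameter documentation onto each entry point's own docstring. A minimal, self-contained sketch of the pattern (the `shared_doc` and `demo_entry_point` names are toy stand-ins, not part of the commit):

# Sketch of the docstring-appending pattern used in the hubconf above.
shared_doc = """
    Params:
        pretrained_model_name_or_path: model name or path
"""

def _append_from_pretrained_docstring(docstr):
    def docstring_decorator(fn):
        # Concatenate the shared param doc onto the function's own docstring.
        fn.__doc__ = fn.__doc__ + docstr
        return fn
    return docstring_decorator

@_append_from_pretrained_docstring(shared_doc)
def demo_entry_point():
    """Load a toy model."""

print(demo_entry_point.__doc__)
# prints: Load a toy model.
#     Params:
#         pretrained_model_name_or_path: model name or path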

pytorch_pretrained_bert/modeling_transfo_xl.py
...
@@ -888,8 +888,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         pass

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
-                        from_tf=False, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """
         Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
...
@@ -897,19 +896,25 @@ class TransfoXLPreTrainedModel(nn.Module):
         Params:
             pretrained_model_name_or_path: either:
                 - a str with the name of a pre-trained model to load selected in the list of:
-                    . `transfo-xl`
+                    . `transfo-xl-wt103`
                 - a path or url to a pretrained model archive containing:
                     . `transfo_xl_config.json` a configuration file for the model
                     . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
                 - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
+                    . `transfo_xl_config.json` a configuration file for the model
                     . `model.chkpt` a TensorFlow checkpoint
             from_tf: should we load the weights from a locally saved TensorFlow checkpoint
             cache_dir: an optional path to a folder in which the pre-trained models will be cached.
             state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
-            *inputs, **kwargs: additional input for the specific Bert class
-                (ex: num_labels for BertForSequenceClassification)
+            *inputs, **kwargs: additional inputs for the specific TransformerXL class
         """
+        state_dict = kwargs.get('state_dict', None)
+        kwargs.pop('state_dict', None)
+        cache_dir = kwargs.get('cache_dir', None)
+        kwargs.pop('cache_dir', None)
+        from_tf = kwargs.get('from_tf', False)
+        kwargs.pop('from_tf', None)
+
         if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
             archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
             config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
...
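
A note on the new argument handling: torch.hub's `load` forwards extra keyword arguments to the entry point, so moving `state_dict`, `cache_dir`, and `from_tf` out of the signature and into `**kwargs` is what lets the hub wrappers above drive `from_pretrained`. Each `kwargs.get(...)`/`kwargs.pop(...)` pair could also collapse into a single call, since `dict.pop` already returns the removed value (or the default when the key is absent). A sketch of the equivalent, more compact form (the helper name `_extract_hub_kwargs` is illustrative, not part of the commit):

# Equivalent, more compact extraction: dict.pop removes the key and
# returns its value, or the supplied default when the key is absent.
def _extract_hub_kwargs(kwargs):
    state_dict = kwargs.pop('state_dict', None)
    cache_dir = kwargs.pop('cache_dir', None)
    from_tf = kwargs.pop('from_tf', False)
    return state_dict, cache_dir, from_tf

# Remaining kwargs pass through to the model constructor untouched.
kw = {'from_tf': True, 'dropout': 0.1}
print(_extract_hub_kwargs(kw))  # (None, None, True)
print(kw)                       # {'dropout': 0.1}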