from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
from pytorch_transformers.modeling_transfo_xl import (
    TransfoXLModel,
    TransfoXLLMHeadModel
)

# A lot of models share the same param doc. Use a decorator
# to save typing
transformer_xl_docstring = """
    Transformer-XL uses relative positioning (with sinusoidal patterns) and an adaptive softmax input, which means that:
    - you don't need to specify position embedding indices
    - the tokens in the vocabulary have to be sorted by decreasing frequency.

    Params:
        pretrained_model_name_or_path: either:
            - a str with the name of a pre-trained model to load selected in the list of:
                . `transfo-xl-wt103`
            - a path or url to a pretrained model archive containing:
                . `transfo_xl_config.json` a configuration file for the model
                . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
            - a path or url to a pretrained model archive containing:
                . `transfo_xl_config.json` a configuration file for the model
                . `model.chkpt` a TensorFlow checkpoint
        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
        *inputs, **kwargs: additional inputs for the specific Transformer-XL class
"""


def _append_from_pretrained_docstring(docstr):
    def docstring_decorator(fn):
        fn.__doc__ = fn.__doc__ + docstr
        return fn
    return docstring_decorator
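
# For illustration, a hypothetical function (not one of the hub entry points
# below) decorated the same way: after decoration, `f.__doc__` is f's own
# docstring followed by the shared parameter documentation above.
#
#     @_append_from_pretrained_docstring(transformer_xl_docstring)
#     def f(*args, **kwargs):
#         """f instantiates a model."""
#
#     # help(f) now also lists the shared `from_pretrained` parameters.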


def transformerXLTokenizer(*args, **kwargs):
    """
    Instantiate a Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl

    Args:
    pretrained_model_name_or_path: Path to pretrained model archive
                                   or one of pre-trained vocab configs below.
                                       * transfo-xl-wt103

    Example:
        >>> import torch
        >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')
        
        >>> text = "Who was Jim Henson ?"
        >>> tokenized_text = tokenizer.tokenize(text)
        >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    """
    tokenizer = TransfoXLTokenizer.from_pretrained(*args, **kwargs)
    return tokenizer
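
# Note: a sketch of loading from a local path instead of a model name,
# assuming the directory holds a vocabulary previously saved with
# `tokenizer.save_pretrained` (the path is hypothetical):
#
#     tokenizer = transformerXLTokenizer('./my_transfo_xl_vocab')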


@_append_from_pretrained_docstring(transformer_xl_docstring)
def transformerXLModel(*args, **kwargs):
    """
    transformerXLModel is the basic Transformer-XL model.

    Example:
        # Load the tokenizer
        >>> import torch
        >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')

        #  Prepare tokenized input
        >>> text_1 = "Who was Jim Henson ?"
        >>> text_2 = "Jim Henson was a puppeteer"
        >>> tokenized_text_1 = tokenizer.tokenize(text_1)
        >>> tokenized_text_2 = tokenizer.tokenize(text_2)
        >>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
        >>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
        >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
        >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])

        # Load transformerXLModel
        >>> model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLModel', 'transfo-xl-wt103')
        >>> model.eval()

        # Predict hidden states features for each layer
        # We can re-use the memory cells in a subsequent call to attend a longer context
        >>> with torch.no_grad():
                hidden_states_1, mems_1 = model(tokens_tensor_1)
                hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
    """
    model = TransfoXLModel.from_pretrained(*args, **kwargs)
    return model


@_append_from_pretrained_docstring(transformer_xl_docstring)
def transformerXLLMHeadModel(*args, **kwargs):
    """
    transformerXLLMHeadModel is the basic Transformer-XL model with the
    tied (pre-trained) language modeling head on top.

    Example:
        # Load the tokenizer
        >>> import torch
        >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')

        #  Prepare tokenized input
        >>> text_1 = "Who was Jim Henson ?"
        >>> text_2 = "Jim Henson was a puppeteer"
        >>> tokenized_text_1 = tokenizer.tokenize(text_1)
        >>> tokenized_text_2 = tokenizer.tokenize(text_2)
        >>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
        >>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
        >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
        >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])

        # Load transformerXLLMHeadModel
        >>> model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLLMHeadModel', 'transfo-xl-wt103')
        >>> model.eval()

        # Predict all tokens
        # We can re-use the memory cells in a subsequent call to attend a longer context
        >>> with torch.no_grad():
                predictions_1, mems_1 = model(tokens_tensor_1)
                predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1)

        # Get the predicted last token
        >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
        >>> predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        >>> assert predicted_token == 'who'
    """
    model = TransfoXLLMHeadModel.from_pretrained(*args, **kwargs)
    return model
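

# A minimal end-to-end sketch (not one of the hub entry points): greedily
# extend a prompt token by token, re-using the Transformer-XL memory cells so
# each step only feeds the newly generated token. The prompt and step count
# are arbitrary.
#
#     import torch
#     tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')
#     model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLLMHeadModel', 'transfo-xl-wt103')
#     model.eval()
#     tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Jim Henson was a"))
#     tokens_tensor, mems = torch.tensor([tokens]), None
#     with torch.no_grad():
#         for _ in range(5):
#             predictions, mems = model(tokens_tensor, mems=mems)
#             predicted_index = torch.argmax(predictions[0, -1, :]).item()
#             tokens.append(predicted_index)
#             tokens_tensor = torch.tensor([[predicted_index]])
#     print(tokenizer.convert_ids_to_tokens(tokens))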