# transformer_xl_hubconf.py: torch.hub entry points for the Transformer-XL tokenizer and models.
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
from pytorch_transformers.modeling_transfo_xl import (
    TransfoXLModel,
    TransfoXLLMHeadModel,
)

# Several entry points below share the same ``from_pretrained`` parameter
# documentation; ``_append_from_pretrained_docstring`` appends this text to
# their docstrings so it only has to be written once.
transformer_xl_docstring = """
    Transformer XL uses relative positioning (with sinusoidal patterns) and adaptive softmax inputs, which means that:
    - you don't need to specify positional embedding indices;
    - the tokens in the vocabulary have to be sorted by decreasing frequency.

    Params:
        pretrained_model_name_or_path: either:
            - a str with the name of a pre-trained model to load, selected in the list of:
                . `transfo-xl-wt103`
            - a path or url to a pretrained model archive containing:
                . `transfo_xl_config.json` a configuration file for the model
                . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
            - a path or url to a pretrained model archive containing:
                . `transfo_xl_config.json` a configuration file for the model
                . `model.chkpt` a TensorFlow checkpoint
        from_tf: whether to load the weights from a locally saved TensorFlow checkpoint
        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
        *inputs, **kwargs: additional input for the specific TransformerXL class
"""


def _append_from_pretrained_docstring(docstr):
    def docstring_decorator(fn):
        fn.__doc__ = fn.__doc__ + docstr
        return fn
    return docstring_decorator


def transformerXLTokenizer(*args, **kwargs):
    """
    Instantiate a Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl

    Args:
        pretrained_model_name_or_path (first positional argument): path to a
            pretrained model archive, or one of the pre-trained vocab configs below:
                * transfo-xl-wt103
        *args, **kwargs: forwarded unchanged to ``TransfoXLTokenizer.from_pretrained``.

    Returns:
        The tokenizer returned by ``TransfoXLTokenizer.from_pretrained``.

    Example:
        import torch
        tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')

        text = "Who was Jim Henson ?"
        tokenized_text = tokenizer.tokenize(text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    """
    tokenizer = TransfoXLTokenizer.from_pretrained(*args, **kwargs)
    return tokenizer


@_append_from_pretrained_docstring(transformer_xl_docstring)
def transformerXLModel(*args, **kwargs):
    """
    transformerXLModel is the basic Transformer XL model.

    All arguments are forwarded unchanged to ``TransfoXLModel.from_pretrained``,
    and the model it returns is returned as-is.

    Example:
        # Load the tokenizer
        import torch
        tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')

        #  Prepare tokenized input
        text_1 = "Who was Jim Henson ?"
        text_2 = "Jim Henson was a puppeteer"
        tokenized_text_1 = tokenizer.tokenize(text_1)
        tokenized_text_2 = tokenizer.tokenize(text_2)
        indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
        indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
        tokens_tensor_1 = torch.tensor([indexed_tokens_1])
        tokens_tensor_2 = torch.tensor([indexed_tokens_2])

        # Load transformerXLModel
        model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLModel', 'transfo-xl-wt103')
        model.eval()

        # Compute hidden-state features
        # We can re-use the memory cells in a subsequent call to attend a longer context
        with torch.no_grad():
            hidden_states_1, mems_1 = model(tokens_tensor_1)
            hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
    """
    model = TransfoXLModel.from_pretrained(*args, **kwargs)
    return model


@_append_from_pretrained_docstring(transformer_xl_docstring)
def transformerXLLMHeadModel(*args, **kwargs):
    """
    transformerXLLMHeadModel is the basic Transformer XL model with the
    tied (pre-trained) language modeling head on top.

    All arguments are forwarded unchanged to ``TransfoXLLMHeadModel.from_pretrained``,
    and the model it returns is returned as-is.

    Example:
        # Load the tokenizer
        import torch
        tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')

        #  Prepare tokenized input
        text_1 = "Who was Jim Henson ?"
        text_2 = "Jim Henson was a puppeteer"
        tokenized_text_1 = tokenizer.tokenize(text_1)
        tokenized_text_2 = tokenizer.tokenize(text_2)
        indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
        indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
        tokens_tensor_1 = torch.tensor([indexed_tokens_1])
        tokens_tensor_2 = torch.tensor([indexed_tokens_2])

        # Load transformerXLLMHeadModel
        model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLLMHeadModel', 'transfo-xl-wt103')
        model.eval()

        # Compute the language-modeling predictions
        # We can re-use the memory cells in a subsequent call to attend a longer context
        with torch.no_grad():
            predictions_1, mems_1 = model(tokens_tensor_1)
            predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1)

        # Get the predicted last token
        predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
        predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        assert predicted_token == 'who'
    """
    model = TransfoXLLMHeadModel.from_pretrained(*args, **kwargs)
    return model