"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9cf7b23b9bda5ae0e827993e8154d17065ef8dab"
Unverified commit 0e774e57, authored by Thomas Wolf and committed by GitHub

Update readme

Adding details on how to extract a full list of hidden states for the Transformer-XL
parent c35d9d48
@@ -624,6 +624,18 @@ This model *outputs* a tuple of (last_hidden_state, new_mems)
- `last_hidden_state`: the encoded hidden states at the top of the model, as a torch.FloatTensor of size [batch_size, sequence_length, self.config.d_model]
- `new_mems`: list (num layers) of updated mem states at the entry of each layer; each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regard to `input_ids` (see the usage sketch below).
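A minimal usage sketch of these two outputs, assuming the `pytorch_pretrained_bert` package and the `'transfo-xl-wt103'` pre-trained weights shortcut; it shows how `new_mems` can be fed back through the `mems` argument so a second segment attends to the longer context:

```python
import torch
from pytorch_pretrained_bert import TransfoXLTokenizer, TransfoXLModel

# Load the pre-trained tokenizer and model (shortcut name assumed)
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()

# Two consecutive text segments
tokens_tensor = torch.tensor([tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Who was Jim Henson ?"))])
tokens_tensor_2 = torch.tensor([tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Jim Henson was a puppeteer"))])

with torch.no_grad():
    # First segment: `last_hidden_state` is [batch_size, seq_length, d_model],
    # `new_mems` is a list of [mem_len, batch_size, d_model] tensors
    last_hidden_state, new_mems = model(tokens_tensor)
    # Second segment: feed the memory back so the model attends to the longer context
    last_hidden_state_2, new_mems_2 = model(tokens_tensor_2, mems=new_mems)
```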
##### Extracting a list of the hidden states at each layer of the Transformer-XL from `last_hidden_state` and `new_mems`:
The `new_mems` contain all the hidden states PLUS the output of the embeddings (`new_mems[0]`). `new_mems[-1]` is the output of the layer below the last layer, and `last_hidden_state` is the output of the last layer (i.e. the input of the softmax when we have a language modeling head on top).
There are two differences between the shapes of `new_mems` and `last_hidden_state`: the first two dimensions of `new_mems` are transposed, and its entries are longer (of length `self.config.mem_len` along the sequence dimension). Here is how to extract the full list of hidden states from the model output:
```python
# `hidden_states` is the output of the last layer: [batch_size, seq_length, d_model]
# `mems` is a list of the lower layers' hidden states (plus the embedding output),
# each of shape [mem_len, batch_size, d_model]
hidden_states, mems = model(tokens_tensor)
seq_length = hidden_states.size(1)
# keep the last `seq_length` positions of each memory and transpose back to [batch_size, seq_length, d_model]
lower_hidden_states = [t[-seq_length:, ...].transpose(0, 1) for t in mems]
all_hidden_states = lower_hidden_states + [hidden_states]
```
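If a single tensor is more convenient, the list can then be stacked; this small illustrative extension holds provided `self.config.mem_len` is at least `seq_length`, so every entry above ends up with the same `[batch_size, seq_length, self.config.d_model]` shape:

```python
# Stack into one tensor of shape [n_layers + 1, batch_size, seq_length, d_model]
# (the +1 accounts for the embedding output stored in new_mems[0])
all_hidden_states_tensor = torch.stack(all_hidden_states, dim=0)
```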
#### 13. `TransfoXLLMHeadModel`
`TransfoXLLMHeadModel` includes the `TransfoXLModel` Transformer followed by an (adaptive) softmax head with weights tied to the input embeddings.
@@ -637,7 +649,6 @@ This model *outputs* a tuple of (last_hidden_state, new_mems)
- else: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
- `new_mems`: list (num layers) of updated mem states at the entry of each layer; each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regard to `input_ids` (a short prediction sketch follows below).
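As a rough illustration, again assuming the `pytorch_pretrained_bert` package and the `'transfo-xl-wt103'` shortcut, calling the model without a target returns the log probabilities described above, which can be used to predict the next token:

```python
import torch
from pytorch_pretrained_bert import TransfoXLTokenizer, TransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
lm_model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
lm_model.eval()

tokens_tensor = torch.tensor([tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Who was Jim Henson ?"))])

with torch.no_grad():
    # Without a target, the first output holds log probabilities of shape
    # [batch_size, sequence_length, n_tokens]; the second is the updated memory
    log_probs, new_mems = lm_model(tokens_tensor)

# Most likely token following the last position
predicted_index = torch.argmax(log_probs[0, -1, :]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
```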
### Tokenizers:
#### `BertTokenizer`