Commit eb8fda51 authored by thomwolf

update docstrings

parent e77721e4
@@ -984,7 +984,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
Inputs:
`input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
with the token indices selected in the range [0, self.config.n_token[
`mems`: optional memory of hidden states from previous forward passes
as a list (num layers) of hidden states at the entry of each layer
each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]
Outputs:
A tuple of (last_hidden_state, new_mems)
`last_hidden_state`: the encoded-hidden-states at the top of the model
@@ -1220,6 +1222,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
def forward(self, input_ids, mems=None):
""" Params:
input_ids :: [len, bsz]
mems :: optional mems from previous forward passes (or init_mems)
list (num layers) of mem states at the entry of each layer
shape :: [self.config.mem_len, bsz, self.config.d_model]
Returns:
tuple (last_hidden, new_mems) where:
new_mems: list (num layers) of mem states at the entry of each layer
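As an illustration of the forward interface documented above, here is a minimal usage sketch of how the returned mems are fed back across segments. It assumes TransfoXLModel is importable from the pytorch_pretrained_bert package, that the 'transfo-xl-wt103' pretrained shortcut is available, and uses arbitrary token indices purely for shape illustration.

import torch
from pytorch_pretrained_bert import TransfoXLModel

# Pretrained weights are downloaded on first use (assumed shortcut name)
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()

# Inputs use the [sequence_length, batch_size] layout described in the docstring
seq_len, bsz = 16, 2
segment_1 = torch.randint(0, 1000, (seq_len, bsz))
segment_2 = torch.randint(0, 1000, (seq_len, bsz))

with torch.no_grad():
    # First segment: no mems passed, the model initializes them internally
    last_hidden_1, mems_1 = model(segment_1)
    # Second segment: pass the returned mems so attention can reach the previous segment
    last_hidden_2, mems_2 = model(segment_2, mems=mems_1)

# last_hidden_2 has shape [seq_len, bsz, d_model];
# mems_2 is a list of n_layer tensors of shape [mem_len, bsz, d_model]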
@@ -1250,8 +1255,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
Inputs:
`input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
with the token indices selected in the range [0, self.config.n_token[
`target`: an optional torch.LongTensor of shape [sequence_length, batch_size]
with the target token indices selected in the range [0, self.config.n_token[
`mems`: an optional memory of hidden states from previous forward passes
as a list (num layers) of hidden states at the entry of each layer
each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]
Outputs:
A tuple of (last_hidden_state, new_mems)
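Similarly, a hedged sketch of the TransfoXLLMHeadModel inputs listed above, passing the optional target and mems by the keyword names used in the docstring. The 'transfo-xl-wt103' shortcut, the keyword argument names, and the exact meaning of the first returned value are assumptions drawn from the documented inputs, not a definitive API reference.

import torch
from pytorch_pretrained_bert import TransfoXLLMHeadModel

model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
model.eval()

seq_len, bsz = 16, 2
input_ids = torch.randint(0, 1000, (seq_len, bsz))
# Language-modeling target of the same shape; here simply the inputs shifted by one step
target = torch.roll(input_ids, shifts=-1, dims=0)

with torch.no_grad():
    # With `target` given, the first output carries the target-dependent scores/loss
    # (per the docstring above); `mems` is reusable across calls as for TransfoXLModel
    out_1, mems_1 = model(input_ids, target=target)
    out_2, mems_2 = model(input_ids, target=target, mems=mems_1)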