"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "55bc0c599a53c71028ce97613e44978681d68d14"
Commit eb8fda51 authored by thomwolf

update docstrings

parent e77721e4
@@ -984,7 +984,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
     Inputs:
         `input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
             with the token indices selected in the range [0, self.config.n_token[
+        `mems`: optional memory of hidden states from previous forward passes
+            as a list (num layers) of hidden states at the entry of each layer
+            each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]
     Outputs:
         A tuple of (last_hidden_state, new_mems)
             `last_hidden_state`: the encoded-hidden-states at the top of the model
@@ -1220,6 +1222,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
     def forward(self, input_ids, mems=None):
         """ Params:
                 input_ids :: [len, bsz]
+                mems :: optional mems from previous forward passes (or init_mems)
+                    list (num layers) of mem states at the entry of each layer
+                    shape :: [self.config.mem_len, bsz, self.config.d_model]
             Returns:
                 tuple (last_hidden, new_mems) where:
                     new_mems: list (num layers) of mem states at the entry of each layer
@@ -1250,8 +1255,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
     Inputs:
         `input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
             with the token indices selected in the range [0, self.config.n_token[
-        `target`: a torch.LongTensor of shape [sequence_length, batch_size]
+        `target`: an optional torch.LongTensor of shape [sequence_length, batch_size]
             with the target token indices selected in the range [0, self.config.n_token[
+        `mems`: an optional memory of hidden states from previous forward passes
+            as a list (num layers) of hidden states at the entry of each layer
+            each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]
     Outputs:
         A tuple of (last_hidden_state, new_mems)
......
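The docstrings above all describe the same mechanism: each forward pass returns `new_mems`, which is fed back as `mems` on the next segment so attention can reach into the previous segment's hidden states. Below is a minimal sketch of that loop; the `pytorch_pretrained_bert` import path, the `transfo-xl-wt103` checkpoint name, and the loss-vs-log-probability behavior of the LM head are assumptions based on the library release this commit belongs to, and the random token ids are placeholders only.

```python
import torch
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLLMHeadModel

# NOTE: the import path and checkpoint name are assumptions for the
# pytorch-pretrained-bert release this commit belongs to.
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()

# Two consecutive segments of one token stream, shaped
# [sequence_length, batch_size] as the docstring requires
# (random ids here, purely as placeholders).
seg_1 = torch.randint(0, model.config.n_token, (16, 1))
seg_2 = torch.randint(0, model.config.n_token, (16, 1))

with torch.no_grad():
    # First pass: no mems are passed, so the model initializes them itself.
    last_hidden_1, mems = model(seg_1)
    # Second pass: feeding the returned mems lets attention look back into
    # the previous segment. mems is a list (one per layer) of tensors of
    # shape [config.mem_len, batch_size, config.d_model].
    last_hidden_2, mems = model(seg_2, mems=mems)

# With the LM head, the optional `target` (same shape as `input_ids`)
# switches the first output from log-probabilities to per-token losses
# (assumed behavior for this era of the API).
lm_model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
lm_model.eval()
with torch.no_grad():
    losses, mems = lm_model(seg_1, target=seg_1, mems=None)
```

Carrying `mems` this way is what gives Transformer-XL its segment-level recurrence: the effective context grows beyond a single segment without recomputing the cached hidden states.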