Unverified Commit c94cd097 authored by Avelina9X, committed by GitHub

Updated missing docstrings for args and returns in bert_padding.py (#795)

* Updated docstrings of bert_padding.py

Added docstrings for missing arguments in the unpad and pad methods.

* Update bert_padding.py

Fixed spelling mistakes
parent ffc8682d
@@ -102,6 +102,7 @@ def unpad_input(hidden_states, attention_mask):
         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
     Return:
         hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
         max_seqlen_in_batch: int
     """
@@ -170,6 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
         attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
     Return:
         hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
         max_seqlen_in_batch: int
     """
@@ -198,7 +200,9 @@ def pad_input(hidden_states, indices, batch, seqlen):
     """
     Arguments:
         hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
-        indices: (total_nnz)
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
     Return:
         hidden_states: (batch, seqlen, ...)
     """
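For illustration, here is a minimal round-trip sketch of the signatures these docstrings describe, assuming the flash_attn.bert_padding import path and the four-value return of unpad_input as of this commit; the concrete tensors and shapes are illustrative and not part of the diff:

import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, dim = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, dim)
# 1 marks a valid token, 0 marks padding; the two rows have lengths 3 and 2.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# unpad_input flattens the batch to (total_nnz, ...) and returns the
# bookkeeping values documented above.
hidden_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
    hidden_states, attention_mask
)
assert hidden_unpad.shape == (5, dim)        # total_nnz = attention_mask.sum() = 5
assert indices.tolist() == [0, 1, 2, 4, 5]   # positions in the flattened (batch*seqlen) view
assert cu_seqlens.tolist() == [0, 3, 5]      # (batch + 1) cumulative sequence lengths
assert max_seqlen_in_batch == 3

# pad_input inverts the operation using indices, batch, and seqlen;
# previously masked positions come back as zeros.
hidden_repad = pad_input(hidden_unpad, indices, batch, seqlen)
assert hidden_repad.shape == (batch, seqlen, dim)

unpad_input_for_concatenated_sequences follows the same return convention, except that attention_mask_in_length holds per-row lengths of concatenated sequences instead of a boolean mask.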