- randomly choose prefix index when there's no constraint
- check that prefixes are compatible with the convention.
:param data: torch.Tensor
:param eod_token: int, token_id used to signal end of document
:param partial_prefix_indices: this argument can have multiple types:
    - None, which signals that all prefix indices are randomly sampled.
    - List[Optional[int]], whose length has to equal the micro batch size. It stores the prefix index of each row.
      A None entry means that the prefix index for that row is sampled randomly.
    - List[List[Optional[int]]], which has to follow these rules:
        - The first dimension indexes the samples, ie len(partial_prefix_indices) == len(data)
        - The second dimension indexes the documents of that sample, ie
          len(partial_prefix_indices[b]) == (data[b] == eod_token).sum() (+1 for the last partial document).
        - partial_prefix_indices have to be interleaved with eod_indices, ie
          eod_indices[b][d-1] < partial_prefix_indices[b][d] < eod_indices[b][d] + 1 or is None.
        - A None entry means that the prefix index for that document is sampled randomly.
:param reset_attention_mask: bool, determines if prefixes are to be per document or per row.
:return: Depending on whether the prefix is per document or per row, the method returns:
    - List[List[int]]: prefix indices for each document, in case of per-document prefixes
    - List[int]: prefix indices for each row otherwise.
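Example (a minimal sketch, not part of this function): the per-document convention described above can be checked for a single row with plain Python lists. `check_row_prefixes` is a hypothetical helper, and using 0 as the lower bound for the first document is an assumption made for illustration.

```python
from typing import List, Optional


def check_row_prefixes(row: List[int], eod_token: int,
                       prefixes: List[Optional[int]]) -> bool:
    # Hypothetical helper, for illustration only: checks the interleaving
    # rule for one row, using plain lists instead of tensors.
    eod_indices = [i for i, tok in enumerate(row) if tok == eod_token]
    # A trailing partial document is treated as if it ended at len(row).
    if not row or row[-1] != eod_token:
        eod_indices.append(len(row))
    # One prefix entry is required per document, complete or partial.
    if len(prefixes) != len(eod_indices):
        return False
    prev_eod = 0  # assumption: lower bound used for the first document
    for prefix, eod in zip(prefixes, eod_indices):
        # None means the index will be sampled, so it is always acceptable.
        if prefix is not None and not (prev_eod < prefix < eod + 1):
            return False
        prev_eod = eod
    return True


# Row with two complete documents and one trailing partial document.
row = [5, 6, 0, 7, 8, 9, 0, 4, 4]
ok = check_row_prefixes(row, eod_token=0, prefixes=[1, None, 8])
bad = check_row_prefixes(row, eod_token=0, prefixes=[1, 2, 8])
```

In this sketch `ok` is accepted because every given index falls strictly after the previous eod index and no later than the current one, while `bad` is rejected because index 2 coincides with the first eod index.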
"""
micro_batch_size, seq_length = data.size()
prefix_indices = []
assert partial_prefix_indices is None or len(partial_prefix_indices) == micro_batch_size, f"partial_prefix_indices has to be None or its length equal to {micro_batch_size}, got {len(partial_prefix_indices)}"
assert partial_prefix_indices is None or len(partial_prefix_indices[batch_id]) == len(eod_indices), f"The number of prefixes has to match the number of documents, complete or partial. Got {len(partial_prefix_indices[batch_id])} prefixes and {len(eod_indices)} documents"
for doc_id, eod_index in enumerate(eod_indices):
    assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], list), f"Per document prefix has to store a list of indices for each row, got {partial_prefix_indices[batch_id]}"
    # Prefix index is defined as the first index that isn't attended by all tokens in a document
    assert prev_index + 1 <= prefix_index <= eod_index, f"Prefix index needs to be between documents indices, {prev_index + 1} <= {prefix_index} <= {eod_index} should be True."