"""Retro project directory, which contains the preprocessed data for for pretraining. This
directory is built during preprocessing (see tools/retro/README.md), and contains
subdirectories for the chunk database and pretraining neighbors.
"""
retro_block_size:int=None
"""Number of records to load per data file, as saved during preprocessing. Block processing is
used for efficient data preprocessing.
"""
retro_chunk_length:int=None
"""Chunk length used for performing chunked- cross-attention (CCA)."""
retro_encoder_num_layers:int=2
"""Number of layers to use for the retrieval encoder."""
retro_encoder_hidden_dropout:float=0.1
"""Hidden dropout for retrieval encoder."""
retro_encoder_attention_dropout:float=0.1
"""Attention dropout for retrieval encoder."""
retro_neighbor_dirs:dict=None
"""Directory names of saved neighbor id files for train, valid, and test datasets."""
retro_num_neighbors:int=2
"""Number of neighbors to retrieve during pretraining."""
retro_num_retrieved_chunks:int=2
"""Number of chunks to retrieve from the retrieval database."""
retro_retrieved_length:int=None
"""Cached value of retro_num_retrieved_chunks * retro_chunk_length (i.e., the total number of
retrieved tokens; neighbor + continuation).
"""
retro_split_preprocessing:str=None
"""Data split used during data preprocessing."""
retro_verify_neighbor_count:bool=True
"""Verify that len(GPT dataset) == len(saved neighbors)."""
# pylint: disable=line-too-long
def__post_init__(self)->None:
"""Validate Retro config."""
super().__post_init__()
# Validate Transformer Engine version.
ifis_te_min_version("1.3"):
try:
assertos.getenv("NVTE_FLASH_ATTN")=="0"
assertos.getenv("NVTE_FUSED_ATTN")=="0"
exceptExceptionase:
raiseException(
"When using Transformer Engine >= 1.3, environment vars NVTE_FLASH_ATTN and NVTE_FUSED_ATTN most both be defined and set to '0'. Currently, NVTE_FLASH_ATTN == %s, NVTE_FUSED_ATTN == %s."
decoder_input (Tensor): When using pipeline parallelism, input_ids and position_ids will only be used on the first stage, and for all other stages decoder_input will be provided via communication from the previous stage.
labels (Tensor): The labels of dimension [batch size, seq length].
inference_params (InferenceParams): Parameters for inference.
Returns:
Output tensor of forward pass.
"""
# Argument shapes:
# Notation:
# ns : Sequence length.
# bs : Batch size.
# d : Hidden size.
# l : Number of chunks per sample (i.e., seq_length/chunk_length).
# k : Number of neighbors.
# r : Number of retrieved tokens (neighbors + continuation).
# - input_ids: [ bs, ns ]
# - context_ids: [ k*bs*l, r ]
# - context: [ r, k*bs*l, d ]
# - output: [ ns, bs, d ]
# Context embedding (e.g., for Retro neighbor tokens).