"When using Transformer Engine >= 1.3, environment vars NVTE_FLASH_ATTN and NVTE_FUSED_ATTN most both be defined and set to '0'. Currently, NVTE_FLASH_ATTN == %s, NVTE_FUSED_ATTN == %s."
decoder_input (Tensor): When using pipeline parallelism, input_ids and position_ids are used only on the first stage; on all other stages, decoder_input is received via communication from the previous stage.
labels (Tensor): The labels of dimension [batch size, seq length].
inference_params (InferenceParams): Parameters for inference.
Returns:
Output tensor of forward pass.
"""
# Argument shapes:
# Notation:
# ns : Sequence length.
# bs : Batch size.
# d : Hidden size.
# l : Number of chunks per sample (i.e., seq_length/chunk_length).
# k : Number of neighbors.
# r : Number of retrieved tokens (neighbors + continuation).
# - input_ids: [ bs, ns ]
# - context_ids: [ k*bs*l, r ]
# - context: [ r, k*bs*l, d ]
# - output: [ ns, bs, d ]
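# Illustrative shape check (a standalone sketch under assumed example sizes,
# not part of this forward pass): verifies how the notation above ties the
# tensors together for sample values of ns, bs, d, l, k, and r.
import torch

ns, bs, d = 2048, 4, 1024        # sequence length, batch size, hidden size
chunk_length, k, r = 64, 2, 128  # chunk length, neighbors, retrieved tokens
l = ns // chunk_length           # chunks per sample

input_ids = torch.zeros(bs, ns, dtype=torch.long)            # [ bs, ns ]
context_ids = torch.zeros(k * bs * l, r, dtype=torch.long)   # [ k*bs*l, r ]
context = torch.zeros(r, k * bs * l, d)                      # [ r, k*bs*l, d ]
output = torch.zeros(ns, bs, d)                              # [ ns, bs, d ]

assert input_ids.shape == (bs, ns)
assert context_ids.shape == (k * bs * l, r)
assert context.shape == (r, k * bs * l, d)
assert output.shape == (ns, bs, d)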
# Context embedding (e.g., for Retro neighbor tokens).
f'setting number of micro-batches to constant {num_microbatches_calculator.get()}'
)
# Batch-size ramp-up: the number of micro-batches varies as the global batch size grows.
else:
    assert len(rampup_batch_size) == 3, (
        'expected the following '
        'format: --rampup-batch-size <start batch size> '
        '<batch size increment> <ramp-up samples>'
    )
    start_global_batch_size = int(rampup_batch_size[0])
    batch_size_increment = int(rampup_batch_size[1])
    rampup_samples = int(rampup_batch_size[2])

    if rank == 0:
        logger.info(
            f'will use batch size rampup starting from global batch size '
            f'{start_global_batch_size} to global batch size {global_batch_size} '
            f'with batch size increments {batch_size_increment} '
            f'over {rampup_samples} samples.'
        )
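# Hypothetical illustration (not Megatron-LM API) of the schedule those three
# values define, mirroring the log message above: the global batch size starts
# at <start batch size> and grows by <batch size increment> at evenly spaced
# points across <ramp-up samples> consumed samples, capped at the target.
def rampup_global_batch_size(consumed_samples, start_global_batch_size,
                             batch_size_increment, rampup_samples,
                             global_batch_size):
    num_increments = (global_batch_size - start_global_batch_size) // batch_size_increment
    samples_per_increment = rampup_samples / num_increments
    steps = int(consumed_samples / samples_per_increment)
    return min(global_batch_size,
               start_global_batch_size + steps * batch_size_increment)

# Example: --rampup-batch-size 16 8 300000 with global_batch_size 256 grows
# the batch size from 16 by 8 every 10,000 samples until it reaches 256.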