Converts the attention mask of dimension [batch size, 1, seq len] to [batch size, 1, seq len, seq len] or [batch size, 1, 1, seq_len] and makes it binary
Args:
attention_mask (Tensor): The input attention mask
Returns:
Tensor: The extended binary attention mask
"""
# We create a 3D attention mask from a 2D tensor mask.
kv_channels (int): Projection weights dimension in multi-head attention. Obtained from transformer config
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None
rotary_base (int, optional): Base period for rotary position embeddings. Defaults to 10000.
use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly on the GPU. Defaults to False