"""Wraps the given function `f` to support TransformerEngine quantization.
This method does a couple things:
1. Wraps the given function in a Flax linen module. This module does not store any Flax parameters
but can store Flax variables for quantizers if required by the recipe.
2. When the wrapper is called, it provides an additional argument to the given function `f`, 'generate_quantizer_set' as the first argument. 'generate_quantizer_set' is a function that can be called to generate a TransformerEngine/JAX quantizer set object used in TransformerEngine/JAX APIs. 'generate_quantizer_set' will generate quantizers based on the recipe of this TransformerEngineQuantizer object.
Args:
f: The function to wrap. The first argument must be 'generate_quantizer_set'.
name: The name of this wrapped operation. If unspecified, will use `f.__name__`.
Returns:
A Flax linen module that wraps the given function.
"""Creates a Flax module class that performs a dot_general operation with the arguments x and kernel using the given quantization recipe.
This is intended for usage when you already have model parameters initialized and sharded for the kernel weights and you want to replace the GEMM implementation with TE's quantized GEMM using a given recipe.
The data type used to allocate the initial parameters.
This dtype is deprecated and will be removed in a future release. DPA will use the dtype of the inputs instead as this module does not have any parameters.
"""
"""
# NOTE(review): this span is corrupted extraction/diff residue, NOT valid Python as it
# stands: nearly every field declaration appears twice (the old and new sides of a
# unified diff), two "@@ -527,18 +611,48 @@ ..." hunk headers and "..." ellipsis lines
# are embedded as if they were source lines, and inter-token whitespace has been
# stripped (e.g. "head_dim:int"). Resolve against the upstream file before editing.
# Comments below annotate only what is visible in this view; code is left byte-identical.
# Per-head feature dimension of the attention operation.
head_dim:int
head_dim:int
...
@@ -527,18 +611,48 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -527,18 +611,48 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
# Dropout probability applied to the attention weights.
attention_dropout:float=0.0
attention_dropout:float=0.0
attn_mask_type:AttnMaskType="causal"
attn_mask_type:AttnMaskType="causal"
attn_bias_type:AttnBiasType=None
attn_bias_type:AttnBiasType=None
# NOTE(review): duplicated pair shows the diff's old default (jnp.float32) next to the
# new Optional[DType]=None; the inline "# Deprecated" marker and the docstring fragment
# above indicate dtype is deprecated — the module is to use the input dtype instead.
dtype:DType=jnp.float32
dtype:Optional[DType]=None# Deprecated
# Name of the PRNG stream used for attention dropout.
dropout_rng_name:str="dropout"
dropout_rng_name:str="dropout"
float32_logits:bool=False
float32_logits:bool=False
# Layout string for the Q/K/V operands (batch, sequence, heads, dim ordering).
qkv_layout:str="bshd_bshd_bshd"
qkv_layout:str="bshd_bshd_bshd"
scale_factor:Optional[float]=None
scale_factor:Optional[float]=None
# NOTE(review): old default True vs. new sentinel None — the truncated __post_init__
# below warns when this is left as None and (per its warning text) the new default
# becomes False; confirm against the upstream file.
transpose_batch_sequence:bool=True
transpose_batch_sequence:bool|None=None
# Optional (left, right) sliding-window extent; None presumably means no windowing.
window_size:Optional[Tuple[int,int]]=None
window_size:Optional[Tuple[int,int]]=None
max_segments_per_seq:Optional[int]=1
max_segments_per_seq:Optional[int]=1
# Context-parallelism knobs; defaults visibly disable/neutralize them.
context_parallel_causal_load_balanced:bool=False
context_parallel_causal_load_balanced:bool=False
context_parallel_axis:str=""
context_parallel_axis:str=""
context_parallel_strategy:str="DEFAULT"
context_parallel_strategy:str="DEFAULT"
# Checkpoint name tag for the attention context output.
context_checkpoint_name:str="context"
context_checkpoint_name:str="context"
# Softmax variant selector; "vanilla" is the default. Appears only once — likely a
# field newly added on the diff's new side.
softmax_type:str="vanilla"
def__post_init__(self):
# TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
# None implies that the user is relying on defaults, hence warn the user and set the new defaults
ifself.transpose_batch_sequenceisNone:
warnings.warn(
"transpose_batch_sequence defaults to False in DotProductAttention starting"