Unverified commit af7b2b44, authored by Reese Wang, committed by GitHub

[JAX] Expose THD format to the flax module (#1480)



* Expose THD to the flax MHA module

  Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance docs

  Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dfbf4dde
@@ -645,6 +645,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
         )
         arg_shardings = [arg_i.sharding for arg_i in arg_infos]
         arg_shardings[4] = seed_sharding
+        arg_shardings[-1] = arg_shardings[-3]
+        arg_shardings[-2] = arg_shardings[-4]
         arg_shardings = tuple(arg_shardings)
         out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
         impl = partial(FusedAttnFwdPrimitive.impl, config=config)
@@ -1042,7 +1044,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
         dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
         dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
         dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
-        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
+        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
+        arg_shardings[-1] = arg_shardings[-3]
+        arg_shardings[-2] = arg_shardings[-4]
+        arg_shardings = tuple(arg_shardings)
         out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

         def sharded_impl(
...
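In both the forward and backward primitives, the two new lines reuse shardings already computed for earlier arguments: the last two inputs (presumably the new THD sequence-length/offset tensors) take the shardings of the arguments three and four places before them, so the partitioner does not have to infer them or insert reshards. Below is a minimal sketch of what reusing a NamedSharding between related arrays looks like in plain JAX; the mesh axis name, shapes, and variable names are illustrative assumptions, not TE code:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Build a 1-D device mesh; "dp" is an illustrative axis name.
    mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

    # Sharding computed for one argument (e.g. a per-sequence length vector).
    seqlen_sharding = NamedSharding(mesh, PartitionSpec("dp"))

    # Reuse the same sharding object for a related argument, mirroring
    # arg_shardings[-2] = arg_shardings[-4] in the diff above.
    offset_sharding = seqlen_sharding

    seqlens = jax.device_put(jnp.ones((8,), jnp.int32), seqlen_sharding)
    offsets = jax.device_put(jnp.zeros((8,), jnp.int32), offset_sharding)
    assert seqlens.sharding == offsets.sharding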
@@ -24,7 +24,7 @@ from jax.ad_checkpoint import checkpoint_name
 from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
 from .module import LayerNorm, Softmax
-from ..attention import AttnBiasType, AttnMaskType, QKVLayout
+from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
 from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
 from ..attention import fused_attn
 from ..softmax import SoftmaxType
@@ -267,6 +267,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = False
     window_size: Optional[Tuple[int, int]] = None
+    max_segments_per_seq: Optional[int] = 1
     context_parallel_causal_load_balanced: bool = False
     context_parallel_axis: str = ""
@@ -276,7 +277,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         query: Array,
         key: Array,
         value: Array,
-        mask: Optional[Array] = None,
+        sequence_descriptor: Optional[SequenceDescriptor] = None,
         bias: Optional[Array] = None,
         *,
         dropout_rng: Optional[PRNGKey] = None,
@@ -293,8 +294,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             scale_factor = self.scale_factor
             del self.scale_factor

-        # TODO(rewang): integrate THD format
-        if self.qkv_layout == QKVLayout.BS3HD:
+        if self.qkv_layout.is_qkvpacked():
             """qkvpacked format, treat
             query: qkvpacked tensor, shape = [..., 3, h, d]
             key: ignore
@@ -306,7 +306,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             x = fused_attn(
                 (qkv_packed,),
                 bias,
-                mask,
+                sequence_descriptor,
                 seed,
                 attn_mask_type=self.attn_mask_type,
                 attn_bias_type=self.attn_bias_type,
@@ -315,10 +315,11 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
             )
-        elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
+        elif self.qkv_layout.is_kvpacked():
             """kvpacked format, treat
             query: query tensor, shape = [..., h, d]
             key: kvpacked tensor, shape = [..., 2, h, d]
@@ -331,7 +332,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             x = fused_attn(
                 (query, kv_packed),
                 bias,
-                mask,
+                sequence_descriptor,
                 seed,
                 attn_mask_type=self.attn_mask_type,
                 attn_bias_type=self.attn_bias_type,
@@ -340,10 +341,11 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
             )
-        elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
+        elif self.qkv_layout.is_separate():
             if self.transpose_batch_sequence:
                 query = query.transpose([1, 0, 2, 3])
                 key = key.transpose([1, 0, 2, 3])
@@ -351,7 +353,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             x = fused_attn(
                 (query, key, value),
                 bias,
-                mask,
+                sequence_descriptor,
                 seed,
                 attn_mask_type=self.attn_mask_type,
                 attn_bias_type=self.attn_bias_type,
@@ -360,6 +362,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
             )
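The branch conditions above now key on layout families rather than individual enum members, so each branch transparently covers both the bshd and the new thd variants. The real predicates live on transformer_engine.jax.attention.QKVLayout; what follows is only a hedged sketch of the pattern, assuming the thd members mirror the documented string options:

    from enum import Enum

    class QKVLayoutSketch(Enum):
        """Illustrative stand-in for the real QKVLayout enum."""
        BS3HD = "bs3hd"
        BSHD_BS2HD = "bshd_bs2hd"
        BSHD_BSHD_BSHD = "bshd_bshd_bshd"
        T3HD = "t3hd"
        THD_T2HD = "thd_t2hd"
        THD_THD_THD = "thd_thd_thd"

        def is_qkvpacked(self) -> bool:
            # One tensor carries Q, K, and V (bs3hd or t3hd).
            return self in (QKVLayoutSketch.BS3HD, QKVLayoutSketch.T3HD)

        def is_kvpacked(self) -> bool:
            # Separate Q; one tensor carries K and V.
            return self in (QKVLayoutSketch.BSHD_BS2HD, QKVLayoutSketch.THD_T2HD)

        def is_separate(self) -> bool:
            # Q, K, and V are three separate tensors.
            return self in (QKVLayoutSketch.BSHD_BSHD_BSHD, QKVLayoutSketch.THD_THD_THD)

    assert QKVLayoutSketch.T3HD.is_qkvpacked()
    assert QKVLayoutSketch.THD_THD_THD.is_separate()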
@@ -437,6 +440,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

+        .. note:: The THD format only supports the 'padding' and 'causal_padding' mask types.
+
     attn_bias_type: Optional[str], default = None
         Type of the attention bias passed in the attention.
         Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
@@ -451,13 +456,15 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     qkv_layout: str, default = 'bshd_bshd_bshd'
         Specifies the dimensional layout format for the query, key, and value tensors in __call__().
         It indicates how the inputs are processed.
-        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where
+        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}. Where

         * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
           key and value arguments in :attr:`__call__()` are ignored in this layout.
         * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
           tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
         * bshd_bshd_bshd: query, key, and value are separate tensors with shape = [b, s, h, d].
+        * t3hd/thd_t2hd/thd_thd_thd: same layouts as the bshd series, but multiple sequences
+          may be packed into each batch row, also known as sequence packing.

         Explanation of denotations:
@@ -476,6 +483,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. The default value is no sliding window.
+    max_segments_per_seq: Optional[int], default = 1
+        The maximum number of segments per sequence, used with the THD format (sequence packing).
     context_parallel_causal_load_balanced (bool):
         Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
     context_parallel_axis (str): The name of the context parallel axis.
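With sequence packing, each batch row can hold up to max_segments_per_seq independent segments, and the kernel needs per-segment lengths and start offsets rather than a dense mask. Below is a hedged sketch of deriving those arrays for one packed row; exactly how they feed into SequenceDescriptor is an assumption here, so consult transformer_engine.jax.attention.SequenceDescriptor for the actual constructors:

    import jax.numpy as jnp

    max_segments_per_seq = 2   # matches the new module attribute
    max_seqlen = 8             # tokens per packed batch row

    # One batch row packing two segments of lengths 3 and 5 (3 + 5 <= 8).
    seqlens = jnp.array([[3, 5]], dtype=jnp.int32)   # [b, max_segments_per_seq]

    # Start offset of each segment is the cumulative length of those before it.
    offsets = jnp.concatenate(
        [jnp.zeros((1, 1), jnp.int32), jnp.cumsum(seqlens, axis=-1)[:, :-1]],
        axis=-1,
    )
    assert int(seqlens.sum()) <= max_seqlen
    print(offsets)  # [[0 3]]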
@@ -502,6 +511,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = True
     window_size: Optional[Tuple[int, int]] = None
+    max_segments_per_seq: Optional[int] = 1
     context_parallel_causal_load_balanced: bool = False
     context_parallel_axis: str = ""
@@ -511,10 +521,11 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         query: Array,
         key: Array,
         value: Array,
-        mask: Optional[Array] = None,
+        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
         bias: Optional[Array] = None,
         *,
         deterministic: bool = False,
+        mask: Optional[Union[SequenceDescriptor, Array]] = None,
     ) -> Array:
         """
         Parameters
@@ -542,6 +553,15 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             Output tensors.
         """
+        if mask is not None:
+            if sequence_descriptor is not None:
+                raise ValueError(
+                    "sequence_descriptor and mask cannot be provided at the same time."
+                )
+            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
+            sequence_descriptor = mask
+        del mask
+
         # For internal API, we use enum to maintain
         if self.attn_bias_type is None:
             attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
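The shim above keeps the old mask keyword working for a deprecation cycle: passing both spellings is an error, while passing mask alone warns and forwards the value to sequence_descriptor. Here is a self-contained sketch of the same pattern with placeholder names (attend is not a TE function):

    import warnings

    def attend(q, k, v, sequence_descriptor=None, *, mask=None):
        """Placeholder showing the keyword-deprecation shim."""
        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
        del mask
        return sequence_descriptor

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert attend(None, None, None, mask="legacy") == "legacy"
    assert "deprecated" in str(caught[0].message)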
@@ -604,16 +624,18 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         if not use_fused_attn:
             # unfused attention only supports split query, key, value
-            if qkv_layout == QKVLayout.BS3HD:
+            if qkv_layout.is_qkvpacked():
                 query, key, value = jnp.split(query, [1, 2], axis=-3)
                 query, key, value = map(
                     functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                 )
-            elif qkv_layout == QKVLayout.BSHD_BS2HD:
+            elif qkv_layout.is_kvpacked():
                 key, value = jnp.split(key, [1], axis=-3)
                 key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
             else:
-                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
+                assert qkv_layout.is_separate()
+            assert sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray)

             x = _UnfusedDotProductAttention(
                 attention_dropout=self.attention_dropout,
@@ -625,7 +647,15 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 scale_factor=scale_factor,
                 transpose_batch_sequence=self.transpose_batch_sequence,
                 window_size=self.window_size,
-            )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
+            )(
+                query,
+                key,
+                value,
+                sequence_descriptor,
+                bias,
+                dropout_rng=dropout_rng,
+                deterministic=deterministic,
+            )
         else:
             x = _FusedDotProductAttention(
                 attention_dropout=self.attention_dropout,
@@ -637,9 +667,18 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 transpose_batch_sequence=self.transpose_batch_sequence,
                 qkv_layout=qkv_layout,
                 window_size=self.window_size,
+                max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
-            )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
+            )(
+                query,
+                key,
+                value,
+                sequence_descriptor,
+                bias,
+                dropout_rng=dropout_rng,
+                deterministic=deterministic,
+            )

         return x
...
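For callers, the signature change is compatible in three ways, sketched below: the fourth positional slot now means sequence_descriptor but still accepts a plain mask Array, the mask keyword keeps working with a deprecation warning, and sequence_descriptor is the preferred new spelling. Here, dpa and params stand for an already-constructed DotProductAttention module and its initialized variables; they are placeholders, not runnable setup:

    def migrate_call(dpa, params, q, k, v, mask):
        # Old positional call still lines up: the 4th positional argument is
        # now sequence_descriptor, which also accepts a plain mask Array.
        out_positional = dpa.apply(params, q, k, v, mask, deterministic=True)
        # The mask keyword keeps working but emits a deprecation warning.
        out_keyword = dpa.apply(params, q, k, v, mask=mask, deterministic=True)
        # Preferred spelling going forward.
        out_new = dpa.apply(params, q, k, v, sequence_descriptor=mask, deterministic=True)
        return out_positional, out_keyword, out_new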