"tests/benchmarks/vscode:/vscode.git/clone" did not exist on "76066b6dd4b17f35edeac05bd490614d700d3177"
Unverified Commit 20be25a3 authored by Md Fahim Faysal Khan's avatar Md Fahim Faysal Khan Committed by GitHub
Browse files

[ TE-JAX ] Expose cp_strategy argument to DPA api (#2090)



* added cp strategy arg to DPA api
Signed-off-by: default avatarMd Fahim Faysal Khan <mdfahimfaysa@nvidia.com>

* converted DPA cp_strategy to string
Signed-off-by: default avatarMd Fahim Faysal Khan <mdfahimfaysa@nvidia.com>

---------
Signed-off-by: default avatarMd Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
parent f1b18ed0
......@@ -26,6 +26,7 @@ from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn
from ..attention import CPStrategy
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
......@@ -274,6 +275,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
context_checkpoint_name: str = "context"
@nn.compact
......@@ -323,6 +325,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_parallel_strategy=self.context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_kvpacked():
......@@ -350,6 +353,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_parallel_strategy=self.context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_separate():
......@@ -372,6 +376,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_parallel_strategy=self.context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)
else:
......@@ -505,6 +510,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
Optimization parameters
......@@ -529,6 +535,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
context_parallel_strategy: str = "DEFAULT"
context_checkpoint_name: str = "context"
@nn.compact
......@@ -648,6 +655,24 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor = self.scale_factor
del self.scale_factor
# case-insensitive mapping for context parallel strategy
cp_strategy_map = {
"DEFAULT": CPStrategy.DEFAULT,
"ALL_GATHER": CPStrategy.ALL_GATHER,
"ALLGATHER": CPStrategy.ALL_GATHER, # Alternative spelling
"RING": CPStrategy.RING,
}
strategy_key = self.context_parallel_strategy.upper()
if strategy_key in cp_strategy_map:
context_parallel_strategy = cp_strategy_map[strategy_key]
else:
valid_strategies = list(cp_strategy_map.keys())
raise ValueError(
f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
f"Valid options are: {valid_strategies} (case insensitive)"
)
if not use_fused_attn:
# unfused attention only supports splitted query, key, value
if qkv_layout.is_qkvpacked():
......@@ -696,6 +721,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_parallel_strategy=context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)(
query,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment