Unverified Commit 7d576ed2 authored by BoxiangW, committed by GitHub

Change `norm_factor` into `softmax_scale` and add kwarg into `DotProductAttention` (#897)



* Add norm_factor arg into DotProductAttention
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>

* Change kwarg name from `norm_factor` to `softmax_scale`
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>

* Change all norm_factor representations to softmax_scale
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>

* Update transformer_engine/pytorch/attention.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update attention.py to fix a typo
Signed-off-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>

---------
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 6bec91e0
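For context, `DotProductAttention` now accepts the scale directly instead of deriving it from `norm_factor`. A minimal usage sketch of the new keyword (the head count and channel size below are illustrative, and a CUDA-capable environment with Transformer Engine installed is assumed):

```python
from transformer_engine.pytorch import DotProductAttention

# Illustrative sizes; kv_channels is the per-head hidden dimension.
num_attention_heads, kv_channels = 16, 64

# Default behavior is unchanged: softmax_scale=None falls back to
# 1.0 / math.sqrt(kv_channels), matching what norm_factor produced before.
attn = DotProductAttention(num_attention_heads, kv_channels)

# New in this commit: an explicit softmax scale can be passed as a kwarg.
attn_custom = DotProductAttention(num_attention_heads, kv_channels, softmax_scale=0.05)
```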
@@ -1704,14 +1704,14 @@ class UnfusedDotProductAttention(torch.nn.Module):
     def __init__(
         self,
-        norm_factor: float,
+        softmax_scale: float,
         attention_dropout: float = 0.0,
         attention_dropout_ctx: Optional[Callable] = nullcontext,
         layer_number: Optional[int] = None,
     ) -> None:
         super().__init__()
-        self.norm_factor = norm_factor
+        self.softmax_scale = softmax_scale
         self.attention_dropout_ctx = attention_dropout_ctx
         self.layer_number = layer_number
@@ -1790,9 +1790,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
         if is_in_onnx_export_mode() and is_bf16:
             matmul_result = matmul_result.bfloat16()
-        scale = self.norm_factor
+        scale = self.softmax_scale
         if apply_qk_layer_scaling:
-            scale *= self.layer_number
+            scale /= self.layer_number
         # Raw attention scores. [b * np, sq, sk]
         if core_attention_bias_type == "no_bias":
@@ -1801,7 +1801,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
                 query_layer.transpose(0, 1), # [b * np, sq, hn]
                 key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
                 beta=0.0,
-                alpha=(1.0 / scale),
+                alpha=scale,
             )
         elif core_attention_bias_type == "pre_scale_bias":
@@ -1813,7 +1813,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
             matmul_result = (matmul_result.view(
                 output_size[0], output_size[1], output_size[2], output_size[3])
                 + core_attention_bias).view(-1, output_size[2], output_size[3])
-            matmul_result /= scale
+            matmul_result *= scale
         elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
             if core_attention_bias_type == "post_scale_bias":
@@ -1826,7 +1826,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
                 query_layer.transpose(0, 1), # [b * np, sq, hn]
                 key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
                 beta=0.0,
-                alpha=(1.0 / scale),
+                alpha=scale,
             )
             matmul_result = (matmul_result.view(
                 output_size[0], output_size[1], output_size[2], output_size[3])
@@ -2055,7 +2055,7 @@ class FlashAttention(torch.nn.Module):
     def __init__(
         self,
-        norm_factor: float,
+        softmax_scale: float,
         attention_dropout: float = 0.0,
         attention_dropout_ctx: Optional[Callable] = nullcontext,
         attention_type: str = "self",
@@ -2071,7 +2071,7 @@ class FlashAttention(torch.nn.Module):
             _flash_attn_version <= _flash_attn_max_version
         ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
-        self.norm_factor = norm_factor
+        self.softmax_scale = softmax_scale
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attention_dropout = attention_dropout
         self.attention_type = attention_type
@@ -2212,7 +2212,7 @@ class FlashAttention(torch.nn.Module):
                 None, None, None, None,
                 self.attention_dropout if self.training else 0.0,
                 cp_group, cp_global_ranks, cp_stream,
-                softmax_scale=1.0/self.norm_factor,
+                softmax_scale=self.softmax_scale,
                 qkv_format="bshd" if qkv_format=="sbhd" else qkv_format,
                 attn_mask_type=attn_mask_type,
                 deterministic=self.deterministic
@@ -2238,7 +2238,7 @@ class FlashAttention(torch.nn.Module):
                 query_layer, key_layer, value_layer,
                 cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                 self.attention_dropout if self.training else 0.0,
-                softmax_scale=1.0/self.norm_factor, causal="causal" in attn_mask_type,
+                softmax_scale=self.softmax_scale, causal="causal" in attn_mask_type,
                 **fa_optional_forward_kwargs,
             )
@@ -3063,7 +3063,7 @@ class FusedAttention(TransformerEngineBaseModule):
     def __init__(
         self,
-        norm_factor: float,
+        softmax_scale: float,
         attention_dropout: float = 0.0,
         attention_dropout_ctx: Optional[Callable] = nullcontext,
         attention_type: str = "self",
@@ -3072,7 +3072,7 @@ class FusedAttention(TransformerEngineBaseModule):
     ) -> None:
         super().__init__()
-        self.norm_factor = norm_factor
+        self.softmax_scale = softmax_scale
         self.attention_dropout = attention_dropout
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attention_type = attention_type
@@ -3248,7 +3248,7 @@ class FusedAttention(TransformerEngineBaseModule):
                 seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                 self.attention_dropout if self.training else 0.0,
                 cp_group, cp_global_ranks, cp_stream,
-                softmax_scale=1.0/self.norm_factor,
+                softmax_scale=self.softmax_scale,
                 qkv_format=qkv_format,
                 attn_mask_type=attn_mask_type,
                 attn_bias_type=core_attention_bias_type,
@@ -3280,7 +3280,7 @@ class FusedAttention(TransformerEngineBaseModule):
                 query_layer, key_layer, value_layer,
                 qkv_dtype,
                 core_attention_bias,
-                1.0/self.norm_factor,
+                self.softmax_scale,
                 self.attention_dropout if self.training else 0.0,
                 fast_zero_fill,
                 qkv_layout,
@@ -3366,6 +3366,9 @@ class DotProductAttention(torch.nn.Module):
                have different lengths. Please note that these formats do not reflect how
                tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
                For that, please use `_get_qkv_layout` to gain the layout information.
+    softmax_scale: Optional[float], default = `None`
+               softmax scale for the attention scores. If `None`, defaults to
+               `1.0 / math.sqrt(kv_channels)`.
     Parallelism parameters
     ----------------------
@@ -3404,6 +3407,7 @@ class DotProductAttention(torch.nn.Module):
         cp_group: Optional[dist_group_type] = None,
         cp_global_ranks: List[int] = None,
         cp_stream: torch.cuda.Stream = None,
+        softmax_scale: Optional[float] = None,
     ) -> None:
         super().__init__()
@@ -3441,7 +3445,8 @@ class DotProductAttention(torch.nn.Module):
             set_all_rng_states(self.rng_states_tracker.get_states())
             attention_dropout_ctx = self.rng_states_tracker.fork
-        norm_factor = math.sqrt(kv_channels)
+        if softmax_scale is None:
+            softmax_scale = 1.0 / math.sqrt(kv_channels)
         self.device_compute_capability = get_device_compute_capability()
         self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) \
@@ -3477,7 +3482,7 @@ class DotProductAttention(torch.nn.Module):
         }
         if self.use_flash_attention:
-            self.flash_attention = FlashAttention(norm_factor,
+            self.flash_attention = FlashAttention(softmax_scale,
                                                   attention_type=attention_type,
                                                   layer_number=layer_number,
                                                   deterministic=self.deterministic,
@@ -3486,14 +3491,14 @@ class DotProductAttention(torch.nn.Module):
         # Instantiating three types since use of flash-attn and FusedAttention
         # might be ruled out due to forward inputs.
         if self.use_fused_attention:
-            self.fused_attention = FusedAttention(norm_factor,
+            self.fused_attention = FusedAttention(softmax_scale,
                                                   attention_type=attention_type,
                                                   layer_number=layer_number,
                                                   deterministic=self.deterministic,
                                                   **attn_kwargs)
         self.unfused_attention = UnfusedDotProductAttention(
-            norm_factor, **attn_kwargs, layer_number=layer_number)
+            softmax_scale, **attn_kwargs, layer_number=layer_number)
     def _checkpointed_attention_forward(
         self,
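Beyond the rename, the scale is now applied as a multiplier (`alpha=scale`) rather than a divisor (`alpha=(1.0 / scale)`), and QK layer scaling divides the scale instead of multiplying it. A small sketch (the values below are illustrative, not taken from the commit) checking that the old and new formulations yield the same effective scaling:

```python
import math

# Illustrative values; kv_channels and layer_number play the same roles as in
# UnfusedDotProductAttention with apply_qk_layer_scaling enabled.
kv_channels, layer_number = 64, 3

# Old path: norm_factor was a divisor, and layer scaling multiplied it.
norm_factor = math.sqrt(kv_channels)
old_alpha = 1.0 / (norm_factor * layer_number)  # applied via alpha=(1.0 / scale)

# New path: softmax_scale is a multiplier, and layer scaling divides it.
softmax_scale = 1.0 / math.sqrt(kv_channels)    # the new default
new_alpha = softmax_scale / layer_number        # applied via alpha=scale

assert math.isclose(old_alpha, new_alpha)       # identical effective scaling
```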