Unverified Commit 1ec28a2c authored by Wang, Yi, committed by GitHub

ulysses enabling in native attention path (#12563)



* ulysses enabling in native attention path
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* address review comment
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add supports_context_parallel for native attention
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update templated attention
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent de6173c6
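
Background, as a hedged sketch rather than code from this commit: Ulysses-style context parallelism keeps each rank's query/key/value sharded along the sequence dimension, runs an all-to-all so every rank ends up with the full sequence for a subset of heads, computes an ordinary scaled_dot_product_attention locally, then all-to-alls back to the sequence sharding. The single-process emulation below only illustrates that data movement; the shapes, helper names, and emulated world size are made up for the example and are not the diffusers API.

import torch
import torch.nn.functional as F

B, S, H, D, W = 2, 16, 8, 64, 4  # batch, sequence, heads, head_dim, emulated world size

def seq_to_head_all_to_all(x):
    # [W, B, S/W, H, D] -> [W, B, S, H/W, D]: gather the sequence, scatter the heads
    w, b, s, h, d = x.shape
    x = x.reshape(w, b, s, w, h // w, d)      # split heads into w contiguous groups
    x = x.permute(3, 1, 0, 2, 4, 5)           # the head-group index becomes the "rank" axis
    return x.reshape(w, b, w * s, h // w, d)  # concatenate the sequence shards

def head_to_seq_all_to_all(x):
    # inverse transform: [W, B, S, H/W, D] -> [W, B, S/W, H, D]
    w, b, s, h, d = x.shape
    x = x.reshape(w, b, w, s // w, h, d)
    x = x.permute(2, 1, 3, 0, 4, 5)
    return x.reshape(w, b, s // w, w * h, d)

# each emulated rank owns S/W tokens of every head, in [B, S/W, H, D] layout
q = torch.randn(W, B, S // W, H, D)
k, v = torch.randn_like(q), torch.randn_like(q)

q, k, v = (seq_to_head_all_to_all(t) for t in (q, k, v))

# local attention over the full sequence for H/W heads (SDPA expects [B, H, S, D])
out = F.scaled_dot_product_attention(
    q.permute(0, 1, 3, 2, 4).flatten(0, 1),
    k.permute(0, 1, 3, 2, 4).flatten(0, 1),
    v.permute(0, 1, 3, 2, 4).flatten(0, 1),
)
out = out.unflatten(0, (W, B)).permute(0, 1, 3, 2, 4)  # back to [W, B, S, H/W, D]
out = head_to_seq_all_to_all(out)                      # restore [W, B, S/W, H, D]
print(out.shape)  # torch.Size([4, 2, 4, 8, 64])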
@@ -649,6 +649,86 @@ def _(
# ===== Helper functions to use attention backends with templated CP autograd functions =====

def _native_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: Optional["ParallelConfig"] = None,
):
    # Native attention does not return_lse
    if return_lse:
        raise ValueError("Native attention does not support return_lse=True")

    # used for backward pass
    if _save_ctx:
        ctx.save_for_backward(query, key, value)
        ctx.attn_mask = attn_mask
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.enable_gqa = enable_gqa

    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    out = torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )
    out = out.permute(0, 2, 1, 3)

    return out

def _native_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    query, key, value = ctx.saved_tensors
    query.requires_grad_(True)
    key.requires_grad_(True)
    value.requires_grad_(True)

    # recompute attention in [batch, heads, seq, dim] layout and differentiate it
    query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    out = torch.nn.functional.scaled_dot_product_attention(
        query=query_t,
        key=key_t,
        value=value_t,
        attn_mask=ctx.attn_mask,
        dropout_p=ctx.dropout_p,
        is_causal=ctx.is_causal,
        scale=ctx.scale,
        enable_gqa=ctx.enable_gqa,
    )

    # keep `out` in [batch, heads, seq, dim] so its shape matches `grad_out_t`
    grad_out_t = grad_out.permute(0, 2, 1, 3)
    grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
    )

    grad_query = grad_query_t.permute(0, 2, 1, 3)
    grad_key = grad_key_t.permute(0, 2, 1, 3)
    grad_value = grad_value_t.permute(0, 2, 1, 3)

    return grad_query, grad_key, grad_value
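
For orientation, a minimal, hedged sketch of how a forward/backward op pair like the two above can be wired into a custom torch.autograd.Function. The real path goes through _templated_context_parallel_attention, which additionally inserts the context-parallel collectives; the class name and the leaf-tensor usage below are illustrative only.

import torch

class _SketchNativeAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, enable_gqa=False):
        return _native_attention_forward_op(
            ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa,
            return_lse=False, _save_ctx=True, _parallel_config=None,
        )

    @staticmethod
    def backward(ctx, grad_out):
        # the backward op re-runs SDPA and differentiates the recomputation, so grad
        # tracking must be on even though autograd disables it inside backward()
        with torch.enable_grad():
            grad_q, grad_k, grad_v = _native_attention_backward_op(ctx, grad_out)
        # None gradients for the non-tensor arguments of forward()
        return grad_q, grad_k, grad_v, None, None, None, None, None

# usage sketch: [batch, seq, heads, head_dim] leaf tensors
q = torch.randn(1, 32, 8, 64, requires_grad=True)
k = torch.randn(1, 32, 8, 64, requires_grad=True)
v = torch.randn(1, 32, 8, 64, requires_grad=True)
_SketchNativeAttention.apply(q, k, v, None, 0.0, False, None, False).sum().backward()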
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
# forward declaration:
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -1523,6 +1603,7 @@ def _native_flex_attention(
@_AttentionBackendRegistry.register(
    AttentionBackendName.NATIVE,
    constraints=[_check_device, _check_shape],
    supports_context_parallel=True,
)
def _native_attention(
    query: torch.Tensor,
@@ -1538,18 +1619,35 @@ def _native_attention(
) -> torch.Tensor:
    if return_lse:
        raise ValueError("Native attention backend does not support setting `return_lse=True`.")

    if _parallel_config is None:
        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        out = torch.nn.functional.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
        out = out.permute(0, 2, 1, 3)
    else:
        out = _templated_context_parallel_attention(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            enable_gqa,
            return_lse,
            forward_op=_native_attention_forward_op,
            backward_op=_native_attention_backward_op,
            _parallel_config=_parallel_config,
        )

    return out
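
As a quick sanity check of the unchanged code path, a hedged sketch (assuming the remaining `_native_attention` arguments keep their defaults, which the diff does not show) that the `_parallel_config is None` branch matches a direct scaled_dot_product_attention call on [batch, seq, heads, head_dim] tensors:

import torch
import torch.nn.functional as F

q = torch.randn(2, 128, 8, 64)  # [batch, seq, heads, head_dim]
k, v = torch.randn_like(q), torch.randn_like(q)

# reference: plain SDPA in [batch, heads, seq, dim] layout, permuted back
ref = F.scaled_dot_product_attention(
    q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
).permute(0, 2, 1, 3)

out = _native_attention(q, k, v)  # no parallel config -> single-device branch
torch.testing.assert_close(out, ref)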