Unverified Commit 7ce7dfe5 authored by Jaemin Choi's avatar Jaemin Choi Committed by GitHub
Browse files

Use jit_fuser for bias-dropout-add fusion (#589)



* Use jit_fuser for bias-dropout-add fusion
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Use jit_fuser for CP FA kernel
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 696ad6c4
@@ -402,7 +402,7 @@ def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
     return send_recv_reqs

-@torch.jit.script
+@jit_fuser
 def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
     """Merge partial outputs of each step in Flash Attention with context parallelism"""
     softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).transpose(1, 2)
@@ -411,7 +411,7 @@ def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_pe
     out.add_(out_corrected)

-@torch.jit.script
+@jit_fuser
 def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
     """Merge softmax stats of each step in Flash Attention with context parallelism"""
     softmax_lse.exp_()
...
@@ -134,7 +134,7 @@ def get_bias_dropout_add(training: bool) -> Callable:
     return _bias_dropout_add

-@torch.jit.script
+@jit_fuser
 def bias_dropout_add_fused_train_(
     x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
 ) -> torch.Tensor:
@@ -151,7 +151,7 @@ def bias_dropout_add_fused_train(
     return bias_dropout_add_fused_train_(x, bias, residual, prob)

-@torch.jit.script
+@jit_fuser
 def bias_dropout_add_fused_inference_(
     x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
 ) -> torch.Tensor:
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment