Unverified Commit 2aee0591 authored by cyanguwa, committed by GitHub

Update cudnn-frontend to 1.0.3 to fix cuDNN v9 SDPA NaNs (#650)



* Update cudnn-frontend to 1.0.3 to fix cuDNN v9 NaNs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Make d_out contiguous in the backward pass (see the sketch after the sign-off trailers)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Remove cudnnDestroy and let PyTorch handle cleanup
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
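
Background for the d_out fix: the gradient an autograd backward() receives inherits the memory layout of whatever downstream op produced it, so it is not guaranteed to be contiguous, while fused attention kernels read their inputs as densely packed buffers. A minimal sketch of how a non-contiguous gradient arises (shapes are illustrative, not from the real model):

import torch

# The tensor passed into an autograd backward() keeps the strides of the
# op that produced it, so it can be a non-contiguous view.
x = torch.randn(2, 4, 8)
d_out = x.transpose(1, 2)     # same storage, swapped strides, no copy
print(d_out.is_contiguous())  # False

# A fused kernel that assumes a densely packed buffer would read the wrong
# elements from such a view. Repacking first is the guard this commit adds
# at the top of each fused-attention backward():
d_out = d_out.contiguous()
print(d_out.is_contiguous())  # True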
parent ce163f9e
-Subproject commit 9f82dda5c029d15a5f371f0fe003dc0c74a0c987
+Subproject commit a86ad708db725e4d29919bb6fadf8e6cdfa5dc06
@@ -152,11 +152,6 @@ class cudnnExecutionPlanManager {
     }

     ~cudnnExecutionPlanManager() {
-        static thread_local std::once_flag flag;
-        std::call_once(flag, [&] {
-            if (handle_ != nullptr) {
-                cudnnDestroy(handle_);
-            }});
     }

 private:
@@ -1823,6 +1823,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         qkv, out, cu_seqlens = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
@@ -1892,6 +1893,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
@@ -1973,6 +1975,7 @@ class FusedAttnFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
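
All three hunks add the same one-line guard at the entry of backward(). A self-contained sketch of the pattern, with a hypothetical fused_backward standing in for Transformer Engine's fused attention backward call; without the guard, the dense-input assumption is violated, which in real kernels typically surfaces as wrong gradients rather than an exception:

import torch

def fused_backward(d_out, scale):
    # Hypothetical stand-in for a fused kernel that requires dense input.
    assert d_out.is_contiguous()
    return d_out * scale

class ScaledIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, d_out):
        d_out = d_out.contiguous()  # same guard as in this commit
        return fused_backward(d_out, ctx.scale), None

x = torch.randn(2, 3, 4, requires_grad=True)
out = ScaledIdentity.apply(x, 2.0)
grad = torch.ones(2, 4, 3).transpose(1, 2)  # deliberately non-contiguous
out.backward(grad)                          # succeeds: backward() repacks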