Unverified Commit 2aee0591 authored by cyanguwa's avatar cyanguwa Committed by GitHub
Browse files

Update cudnn-frontend to 1.0.3 to fix cuDNN v9 SDPA NaNs (#650)



* Update cudnn-frontend to 1.0.3 to fix cuDNN v9 NaNs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* make d_out contiguous for bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove cudnnDestroy to let torch handle it
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ce163f9e
Subproject commit 9f82dda5c029d15a5f371f0fe003dc0c74a0c987 Subproject commit a86ad708db725e4d29919bb6fadf8e6cdfa5dc06
...@@ -152,11 +152,6 @@ class cudnnExecutionPlanManager { ...@@ -152,11 +152,6 @@ class cudnnExecutionPlanManager {
} }
~cudnnExecutionPlanManager() { ~cudnnExecutionPlanManager() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] {
if (handle_ != nullptr) {
cudnnDestroy(handle_);
}});
} }
private: private:
......
...@@ -1823,6 +1823,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -1823,6 +1823,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
d_out = d_out.contiguous()
qkv, out, cu_seqlens = ctx.saved_tensors qkv, out, cu_seqlens = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous(): if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
...@@ -1892,6 +1893,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -1892,6 +1893,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
d_out = d_out.contiguous()
q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous(): if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
...@@ -1973,6 +1975,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1973,6 +1975,7 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
d_out = d_out.contiguous()
q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous(): if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment