"tests/cpp/operator/test_multi_padding.cu" did not exist on "6b311da2401a0b68bd7775553175763c744c974d"
Unverified Commit 525de6cc authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

Update cudnn-frontend to v1.6.1 (#1108)



* update FE to 1.6
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update to 1.6.1-rc for testing
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to fe 1.6.1
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8e3561bf
Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019 Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b
...@@ -1835,8 +1835,14 @@ void fused_attn_fp8_fwd_impl_v1( ...@@ -1835,8 +1835,14 @@ void fused_attn_fp8_fwd_impl_v1(
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix); NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); amax_o->set_output(true)
amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); .set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_s->set_output(true)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
Stats->set_output(true) Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT) .set_data_type(fe::DataType_t::FLOAT)
...@@ -2182,10 +2188,22 @@ void fused_attn_fp8_bwd_impl_v1( ...@@ -2182,10 +2188,22 @@ void fused_attn_fp8_bwd_impl_v1(
dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride);
dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride);
dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride);
amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); amax_dQ->set_output(true)
amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); .set_dim({1, 1, 1, 1})
amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); .set_stride({1, 1, 1, 1})
amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); .set_data_type(fe::DataType_t::FLOAT);
amax_dK->set_output(true)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_dV->set_output(true)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_dP->set_output(true)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
dO->set_data_type(bwd_tensor_type); dO->set_data_type(bwd_tensor_type);
dQ->set_data_type(bwd_tensor_type); dQ->set_data_type(bwd_tensor_type);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment