"docs/vscode:/vscode.git/clone" did not exist on "c81dddb45c71e630b907f9d84686ecd73b4105c7"
Unverified commit d793ca17, authored by Shijie, committed by GitHub
Browse files

[Paddle] Add deterministic option in DotProductAttention (#956)



add deterministic option
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 81dd6ad0
...@@ -872,8 +872,9 @@ class TestLayerNormMLP: ...@@ -872,8 +872,9 @@ class TestLayerNormMLP:
@pytest.mark.parametrize("attn_type", ["self", "cross"]) @pytest.mark.parametrize("attn_type", ["self", "cross"])
@pytest.mark.parametrize("mask_type", ["causal", "padding"]) @pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) @pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("deterministic", [True, False])
def test_dot_product_attention( def test_dot_product_attention(
bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype, deterministic
): ):
""" """
Test DotProductAttention Layer Test DotProductAttention Layer
...@@ -927,6 +928,10 @@ def test_dot_product_attention( ...@@ -927,6 +928,10 @@ def test_dot_product_attention(
attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
head_size = hidden_size // num_heads head_size = hidden_size // num_heads
if deterministic:
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
layer_te = te.DotProductAttention( layer_te = te.DotProductAttention(
num_heads, num_heads,
head_size, head_size,
...@@ -981,6 +986,15 @@ def test_dot_product_attention( ...@@ -981,6 +986,15 @@ def test_dot_product_attention(
assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol) assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol)
assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol) assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol)
assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol) assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)
if deterministic:
out2, q_grad2, k_grad2, v_grad2 = calc_attn_output_and_grad(
layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
)
assert_allclose(out, out2, rtol=1e-12, atol=1e-12)
assert_allclose(q_grad, q_grad2, rtol=1e-12, atol=1e-12)
assert_allclose(k_grad, k_grad2, rtol=1e-12, atol=1e-12)
assert_allclose(v_grad, v_grad2, rtol=1e-12, atol=1e-12)
os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
@pytest.mark.parametrize("bs", [1, 2]) @pytest.mark.parametrize("bs", [1, 2])
......
...@@ -586,7 +586,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -586,7 +586,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options.set_sliding_window_length(window_size_left); sdpa_backward_options.set_sliding_window_length(window_size_left);
} }
if (cudnn_runtime_version >= 90000 && sm_arch_ >= 90) { if (cudnn_runtime_version >= 90000) {
sdpa_backward_options.set_deterministic_algorithm(deterministic); sdpa_backward_options.set_deterministic_algorithm(deterministic);
} }
......
...@@ -659,6 +659,7 @@ def fused_attn_bwd_qkvpacked( ...@@ -659,6 +659,7 @@ def fused_attn_bwd_qkvpacked(
qkv_layout: str = "bs3hd", qkv_layout: str = "bs3hd",
bias_type: str = "no_bias", bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
deterministic: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed QKV input""" """Fused Attention BWD for packed QKV input"""
...@@ -715,6 +716,7 @@ def fused_attn_bwd_qkvpacked( ...@@ -715,6 +716,7 @@ def fused_attn_bwd_qkvpacked(
bias_type, bias_type,
attn_mask_type, attn_mask_type,
int(qkv_dtype), int(qkv_dtype),
deterministic,
) )
return dqkv, dbias return dqkv, dbias
...@@ -855,6 +857,7 @@ def fused_attn_bwd_kvpacked( ...@@ -855,6 +857,7 @@ def fused_attn_bwd_kvpacked(
qkv_layout: str = "bshd_bs2hd", qkv_layout: str = "bshd_bs2hd",
bias_type: str = "no_bias", bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
deterministic: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed KV input""" """Fused Attention BWD for packed KV input"""
...@@ -921,6 +924,7 @@ def fused_attn_bwd_kvpacked( ...@@ -921,6 +924,7 @@ def fused_attn_bwd_kvpacked(
bias_type, bias_type,
attn_mask_type, attn_mask_type,
int(qkv_dtype), int(qkv_dtype),
deterministic,
) )
return dq, dkv, dbias return dq, dkv, dbias
...@@ -1061,6 +1065,7 @@ def fused_attn_bwd( ...@@ -1061,6 +1065,7 @@ def fused_attn_bwd(
qkv_layout: str = "bshd_bshd_bshd", qkv_layout: str = "bshd_bshd_bshd",
bias_type: str = "no_bias", bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
deterministic: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed KV input""" """Fused Attention BWD for packed KV input"""
...@@ -1130,6 +1135,7 @@ def fused_attn_bwd( ...@@ -1130,6 +1135,7 @@ def fused_attn_bwd(
bias_type, bias_type,
attn_mask_type, attn_mask_type,
int(qkv_dtype), int(qkv_dtype),
deterministic,
) )
return dq, dk, dv, dbias return dq, dk, dv, dbias
......
...@@ -708,7 +708,8 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -708,7 +708,8 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
int64_t b, int64_t h, int64_t d, int64_t total_seqs, int64_t b, int64_t h, int64_t d, int64_t total_seqs,
int64_t max_seqlen, float attn_scale, float p_dropout, int64_t max_seqlen, float attn_scale, float p_dropout,
const std::string &qkv_layout, const std::string &bias_type, const std::string &qkv_layout, const std::string &bias_type,
const std::string &attn_mask_type, int64_t qkv_type) { const std::string &attn_mask_type, int64_t qkv_type,
bool deterministic) {
TensorWrapper te_dBias; TensorWrapper te_dBias;
if (bias_type != "no_bias" && dBias) { if (bias_type != "no_bias" && dBias) {
auto bias_shape = dBias->shape(); auto bias_shape = dBias->shape();
...@@ -759,22 +760,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -759,22 +760,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32); auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), nvte_fused_attn_bwd_qkvpacked(
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream()); deterministic, workspace.data(), QKV.stream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel // execute kernel
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), nvte_fused_attn_bwd_qkvpacked(
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
attn_mask_type_enum, -1, -1, true, workspace.data(), QKV.stream()); deterministic, workspace.data(), QKV.stream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -884,7 +885,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -884,7 +885,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv, int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv,
float attn_scale, float p_dropout, const std::string &qkv_layout, float attn_scale, float p_dropout, const std::string &qkv_layout,
const std::string &bias_type, const std::string &attn_mask_type, const std::string &bias_type, const std::string &attn_mask_type,
int64_t qkv_type) { int64_t qkv_type, bool deterministic) {
TensorWrapper te_dBias; TensorWrapper te_dBias;
if (bias_type != "no_bias" && dBias) { if (bias_type != "no_bias" && dBias) {
auto bias_shape = dBias->shape(); auto bias_shape = dBias->shape();
...@@ -945,7 +946,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -945,7 +946,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-1, -1, true, workspace.data(), Q.stream()); -1, -1, deterministic, workspace.data(), Q.stream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
...@@ -957,7 +958,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -957,7 +958,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-1, -1, true, workspace.data(), Q.stream()); -1, -1, deterministic, workspace.data(), Q.stream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -1086,7 +1087,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1086,7 +1087,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv,
float attn_scale, float p_dropout, const std::string &qkv_layout, float attn_scale, float p_dropout, const std::string &qkv_layout,
const std::string &bias_type, const std::string &attn_mask_type, const std::string &bias_type, const std::string &attn_mask_type,
int64_t qkv_type) { int64_t qkv_type, bool deterministic) {
TensorWrapper te_dBias; TensorWrapper te_dBias;
if (bias_type != "no_bias" && dBias) { if (bias_type != "no_bias" && dBias) {
auto bias_shape = dBias->shape(); auto bias_shape = dBias->shape();
...@@ -1149,7 +1150,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1149,7 +1150,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream()); attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
...@@ -1161,7 +1162,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1161,7 +1162,7 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, true, workspace.data(), Q.stream()); attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -1657,7 +1658,8 @@ PD_BUILD_OP(te_fused_attn_bwd_qkvpacked) ...@@ -1657,7 +1658,8 @@ PD_BUILD_OP(te_fused_attn_bwd_qkvpacked)
.Outputs({"dQKV", paddle::Optional("dBias")}) .Outputs({"dQKV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
"attn_scale: float", "p_dropout: float", "qkv_layout: std::string", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
"bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"}) "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
"deterministic: bool"})
.SetInplaceMap({{"_dQKV", "dQKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) .SetInplaceMap({{"_dQKV", "dQKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_qkvpacked)); .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_qkvpacked));
...@@ -1682,7 +1684,8 @@ PD_BUILD_OP(te_fused_attn_bwd_kvpacked) ...@@ -1682,7 +1684,8 @@ PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t", .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
"total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
"attn_scale: float", "p_dropout: float", "qkv_layout: std::string", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
"bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"}) "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
"deterministic: bool"})
.SetInplaceMap({{"_dQ", "dQ"}, .SetInplaceMap({{"_dQ", "dQ"},
{"_dKV", "dKV"}, {"_dKV", "dKV"},
{paddle::Optional("_dBias"), paddle::Optional("dBias")}}) {paddle::Optional("_dBias"), paddle::Optional("dBias")}})
...@@ -1708,7 +1711,7 @@ PD_BUILD_OP(te_fused_attn_bwd) ...@@ -1708,7 +1711,7 @@ PD_BUILD_OP(te_fused_attn_bwd)
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t",
"max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float", "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float",
"qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string",
"qkv_type: int64_t"}) "qkv_type: int64_t", "deterministic: bool"})
.SetInplaceMap({{"_dQ", "dQ"}, .SetInplaceMap({{"_dQ", "dQ"},
{"_dK", "dK"}, {"_dK", "dK"},
{"_dV", "dV"}, {"_dV", "dV"},
......
...@@ -152,6 +152,7 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): ...@@ -152,6 +152,7 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
is_training, is_training,
deterministic,
fused_attention_backend, fused_attention_backend,
): ):
"""Forward function for FusedAttention with packed QKV input""" """Forward function for FusedAttention with packed QKV input"""
...@@ -180,6 +181,7 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): ...@@ -180,6 +181,7 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
ctx.qkv_layout = qkv_layout ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.deterministic = deterministic
ctx.fused_attention_backend = fused_attention_backend ctx.fused_attention_backend = fused_attention_backend
return out return out
...@@ -204,6 +206,7 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): ...@@ -204,6 +206,7 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
ctx.qkv_layout, ctx.qkv_layout,
ctx.attn_bias_type, ctx.attn_bias_type,
ctx.attn_mask_type, ctx.attn_mask_type,
ctx.deterministic,
) )
# if no_bias, return dqkv # if no_bias, return dqkv
...@@ -234,6 +237,7 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer): ...@@ -234,6 +237,7 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
is_training, is_training,
deterministic,
fused_attention_backend, fused_attention_backend,
): ):
"""Forward function for FusedAttention with packed KV input""" """Forward function for FusedAttention with packed KV input"""
...@@ -266,6 +270,7 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer): ...@@ -266,6 +270,7 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
ctx.qkv_layout = qkv_layout ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.deterministic = deterministic
ctx.fused_attention_backend = fused_attention_backend ctx.fused_attention_backend = fused_attention_backend
return out return out
...@@ -293,6 +298,7 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer): ...@@ -293,6 +298,7 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
ctx.qkv_layout, ctx.qkv_layout,
ctx.attn_bias_type, ctx.attn_bias_type,
ctx.attn_mask_type, ctx.attn_mask_type,
ctx.deterministic,
) )
# if no_bias, return dq, dkv # if no_bias, return dq, dkv
...@@ -324,6 +330,7 @@ class FusedAttnFunc(paddle.autograd.PyLayer): ...@@ -324,6 +330,7 @@ class FusedAttnFunc(paddle.autograd.PyLayer):
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
is_training, is_training,
deterministic,
fused_attention_backend, fused_attention_backend,
): ):
"""Forward function for FusedAttention with separate Q, K, V tensors""" """Forward function for FusedAttention with separate Q, K, V tensors"""
...@@ -357,6 +364,7 @@ class FusedAttnFunc(paddle.autograd.PyLayer): ...@@ -357,6 +364,7 @@ class FusedAttnFunc(paddle.autograd.PyLayer):
ctx.qkv_layout = qkv_layout ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.deterministic = deterministic
ctx.fused_attention_backend = fused_attention_backend ctx.fused_attention_backend = fused_attention_backend
return out return out
...@@ -385,6 +393,7 @@ class FusedAttnFunc(paddle.autograd.PyLayer): ...@@ -385,6 +393,7 @@ class FusedAttnFunc(paddle.autograd.PyLayer):
ctx.qkv_layout, ctx.qkv_layout,
ctx.attn_bias_type, ctx.attn_bias_type,
ctx.attn_mask_type, ctx.attn_mask_type,
ctx.deterministic,
) )
# if no_bias, return dq, dk, dv # if no_bias, return dq, dk, dv
if ctx.attn_bias_type == "no_bias": if ctx.attn_bias_type == "no_bias":
...@@ -404,6 +413,12 @@ class DotProductAttention(paddle.nn.Layer): ...@@ -404,6 +413,12 @@ class DotProductAttention(paddle.nn.Layer):
Argument :attr:`attention_mask` will be ignored in the `forward` call when Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`attn_mask_type` is set to `"causal"`. :attr:`attn_mask_type` is set to `"causal"`.
.. warning::
Fused attention backward uses a non-deterministic algorithm when workspace
optimization is not enabled. To use a deterministic algorithm, set the
environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`
Parameters Parameters
---------- ----------
num_attention_heads: int num_attention_heads: int
...@@ -458,6 +473,29 @@ class DotProductAttention(paddle.nn.Layer): ...@@ -458,6 +473,29 @@ class DotProductAttention(paddle.nn.Layer):
self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1"))) self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))
self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
# To use the workspace optimization path for determinism, please
# set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
# and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
cudnn_version = paddle.get_cudnn_version()
if 8905 <= cudnn_version < 9000:
if self.deterministic:
# workspace optimization path is deterministic
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"
# CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
# - unset: enables workspace optimization when required workspace is <= 256MB
# or when bias gradient needs to be computed
# - n: enables workspace optimization when required workspace is <= n bytes
# - -1: enables workspace optimization always
# - 0: disables workspace optimization always
if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
if not self.use_fused_attention and backend == "transformer_engine": if not self.use_fused_attention and backend == "transformer_engine":
warnings.warn("Fused attention is not enabled, falling back to Paddle backend") warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
self.backend = "paddle" self.backend = "paddle"
...@@ -603,6 +641,7 @@ class DotProductAttention(paddle.nn.Layer): ...@@ -603,6 +641,7 @@ class DotProductAttention(paddle.nn.Layer):
core_attention_bias_type, core_attention_bias_type,
self.attn_mask_type, self.attn_mask_type,
self.training, self.training,
self.deterministic,
self.fused_attention_backend, self.fused_attention_backend,
) )
elif self.attention_type == "cross": elif self.attention_type == "cross":
...@@ -637,6 +676,7 @@ class DotProductAttention(paddle.nn.Layer): ...@@ -637,6 +676,7 @@ class DotProductAttention(paddle.nn.Layer):
core_attention_bias_type, core_attention_bias_type,
self.attn_mask_type, self.attn_mask_type,
self.training, self.training,
self.deterministic,
self.fused_attention_backend, self.fused_attention_backend,
) )
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment