Unverified Commit 238df4ce authored by Charlene Yang, committed by GitHub

Initialize output tensors to 0 for THD (temporary) (#1009)



* initialize output tensors to 0 for THD while waiting for cuDNN bug fix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* move fill_() to F16 loop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused_attn_bwd()
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* correct typo in check_set_window_size
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use nvtx3 instead
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e1c6d218
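The change below is a stopgap: for the THD (packed, variable-sequence-length) layout, the fused-attention kernels write only the valid tokens of each sequence, so buffers created with `paddle.empty` can keep uninitialized garbage in the padding region until the cuDNN fix lands. A minimal sketch of the failure mode and the workaround, with made-up shapes that are not from this diff:

```python
import paddle

# Hypothetical THD-style packed batch: sequences of lengths 2, 3 and 1
# packed into a buffer sized for 8 tokens, leaving 2 padding rows.
total_tokens, h, d = 8, 16, 64
out = paddle.empty(shape=[total_tokens, h, d], dtype="float16")
# ... the attention kernel writes only the first 6 token rows ...
# out[6:] keeps whatever bytes were in the allocation, and that garbage
# can leak into later reductions or gradient accumulation.

# Temporary workaround applied in this commit: zero-fill up front.
out = paddle.full(shape=[total_tokens, h, d], fill_value=0, dtype="float16")
```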
@@ -593,6 +593,9 @@ def fused_attn_fwd_qkvpacked(
     if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
         rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
+    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+    if qkv_format == "thd":
+        set_zero = True
     if set_zero:
         out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype)
     else:
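The added `qkv_format` line keeps only the letters of the first `_`-separated chunk of the layout string, so every THD variant maps to `"thd"`. A standalone sketch of that parsing (layout names follow Transformer Engine's conventions; the helper name is illustrative):

```python
def qkv_format_of(qkv_layout: str) -> str:
    # Keep only alphabetic characters of the first "_"-separated segment.
    return "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

print(qkv_format_of("thd_thd_thd"))     # -> "thd"
print(qkv_format_of("t3hd"))            # -> "thd" (the interleaving digit is dropped)
print(qkv_format_of("bshd_bshd_bshd"))  # -> "bshd" (set_zero keeps its passed-in value)
```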
@@ -676,13 +679,19 @@ def fused_attn_bwd_qkvpacked(
         fused_attention_backend != FusedAttnBackend["No_Backend"]
     ), "Fused attention does not support this input combination."
+    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+    if qkv_format == "thd":
+        set_zero = True
     if set_zero:
         dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
     else:
         dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype)
     if bias_type != "no_bias":
-        dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
+        if qkv_format == "thd":
+            dbias = paddle.zeros(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
+        else:
+            dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
     else:
         dbias = None
     # execute kernel
@@ -772,6 +781,9 @@ def fused_attn_fwd_kvpacked(
     if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
         rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
+    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+    if qkv_format == "thd":
+        set_zero = True
     if set_zero:
         out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
     else:
@@ -867,6 +879,9 @@ def fused_attn_bwd_kvpacked(
         fused_attention_backend != FusedAttnBackend["No_Backend"]
     ), "Fused attention does not support this input combination."
+    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+    if qkv_format == "thd":
+        set_zero = True
     if set_zero:
         dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
         dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
@@ -874,7 +889,10 @@ def fused_attn_bwd_kvpacked(
         dq = paddle.empty(shape=q.shape, dtype=q.dtype)
         dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype)
     if bias_type != "no_bias":
-        dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
+        if qkv_format == "thd":
+            dbias = paddle.zeros(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
+        else:
+            dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
     else:
         dbias = None
     # execute kernel
@@ -970,6 +988,9 @@ def fused_attn_fwd(
     if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
         rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
+    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+    if qkv_format == "thd":
+        set_zero = True
     if set_zero:
         out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
     else:
@@ -1065,6 +1086,9 @@ def fused_attn_bwd(
         fused_attention_backend != FusedAttnBackend["No_Backend"]
     ), "Fused attention does not support this input combination."
+    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+    if qkv_format == "thd":
+        set_zero = True
     if set_zero:
         dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
         dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype)
@@ -1074,7 +1098,10 @@ def fused_attn_bwd(
         dk = paddle.empty(shape=k.shape, dtype=k.dtype)
         dv = paddle.empty(shape=v.shape, dtype=v.dtype)
     if bias_type != "no_bias":
-        dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
+        if qkv_format == "thd":
+            dbias = paddle.zeros(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
+        else:
+            dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
     else:
         dbias = None
     # execute kernel
......
@@ -3068,7 +3068,7 @@ def check_set_window_size(
         warnings.warn(
             "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
         )
-    elif orig_window_size[0] < 0 or orig_window_size[0] < 0:
+    elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
         assert False, (
             "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
         )
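This is the "correct typo in check_set_window_size" commit: the old condition tested `orig_window_size[0]` against zero twice, so a window like `(4, -1)` slipped past the guard. A minimal standalone check with the corrected logic (simplified from `check_set_window_size`; the function name here is illustrative):

```python
def validate_window_size(orig_window_size, attn_mask_type="no_mask"):
    # (-1, -1) means "no sliding window"; otherwise both halves
    # of the window must be non-negative.
    if orig_window_size == (-1, -1):
        return
    if orig_window_size[0] < 0 or orig_window_size[1] < 0:
        raise AssertionError(
            "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
        )

validate_window_size((4, 2))      # passes
# validate_window_size((4, -1))  # now raises; the old typo let this through
```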
......
@@ -127,6 +127,9 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
     te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
                                        scale_O.value().data_ptr(), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
+    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
+      O.fill_(0);
+    }
     // BF16 or FP16
     te_QKV =
         makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr);
@@ -288,6 +291,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
                                          amax_dQKV.value().data_ptr(),
                                          scale_dQKV.value().data_ptr(), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
+    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
+      dQKV.fill_(0);
+    }
     // BF16 or FP16
     te_QKV =
         makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr);
@@ -328,6 +334,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
                            options);
       te_dBias = makeTransformerEngineTensor(dBias);
     }
+    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
+      dBias.fill_(0);
+    }
   }
   // create cu_seqlens tensorwrappers
@@ -427,6 +436,9 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
     te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
                                        scale_O.value().data_ptr(), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
+    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
+      O.fill_(0);
+    }
     // BF16 or FP16
     te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
     te_KV =
@@ -614,6 +626,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                                          amax_dQKV.value().data_ptr(),
                                          scale_dQKV.value().data_ptr(), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
+    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
+      dQ.fill_(0);
+      dKV.fill_(0);
+    }
     // BF16 or FP16
     te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
     te_KV =
@@ -684,6 +700,9 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                            options);
       te_dBias = makeTransformerEngineTensor(dBias);
     }
+    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
+      dBias.fill_(0);
+    }
   }
   // create workspace
@@ -774,6 +793,9 @@ std::vector<at::Tensor> fused_attn_fwd(
     te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
                                        scale_O.value().data_ptr(), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
+    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
+      O.fill_(0);
+    }
     // BF16 or FP16
     te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
     te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
@@ -1037,6 +1059,11 @@ std::vector<at::Tensor> fused_attn_bwd(
         makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, amax_dQKV.value().data_ptr(),
                                     scale_dQKV.value().data_ptr(), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
+    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
+      dQ.fill_(0);
+      dK.fill_(0);
+      dV.fill_(0);
+    }
     // BF16 or FP16
     te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
     te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
@@ -1109,6 +1136,9 @@ std::vector<at::Tensor> fused_attn_bwd(
                            options);
       te_dBias = makeTransformerEngineTensor(dBias);
     }
+    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
+      dBias.fill_(0);
+    }
   }
   // create workspace
......
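On the C++ side the `fill_(0)` calls sit inside the BF16/FP16 branch, matching the "move fill_() to F16 loop" commit above: FP8 paths and non-THD layouts are left untouched. A rough Python-level mirror of that guard, using hypothetical helper and argument names for illustration only:

```python
import paddle

def maybe_zero_outputs(tensors, qkv_layout: str, dtype: str) -> None:
    # Zero outputs only on the BF16/FP16 THD path, mirroring the
    # nvte_get_qkv_format(...) == NVTE_THD check in the C++ diff.
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if dtype in ("bfloat16", "float16") and qkv_format == "thd":
        for t in tensors:
            t.fill_(0)  # in-place, like at::Tensor::fill_ in the diff

out = paddle.empty(shape=[8, 16, 64], dtype="float16")
maybe_zero_outputs([out], qkv_layout="thd_thd_thd", dtype="float16")
```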