Unverified Commit 238df4ce authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

Initialize output tensors to 0 for THD (temporary) (#1009)



* initialize output tensors to 0 for THD while waiting for cuDNN bug fix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* move fill_() to F16 loop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused_attn_bwd()
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* correct typo in check_set_window_size
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use nvtx3 instead
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e1c6d218
......@@ -593,6 +593,9 @@ def fused_attn_fwd_qkvpacked(
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype)
else:
......@@ -676,13 +679,19 @@ def fused_attn_bwd_qkvpacked(
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
else:
dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype)
if bias_type != "no_bias":
dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
if qkv_format == "thd":
dbias = paddle.zero(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
else:
dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
else:
dbias = None
# execute kernel
......@@ -772,6 +781,9 @@ def fused_attn_fwd_kvpacked(
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
else:
......@@ -867,6 +879,9 @@ def fused_attn_bwd_kvpacked(
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
......@@ -874,7 +889,10 @@ def fused_attn_bwd_kvpacked(
dq = paddle.empty(shape=q.shape, dtype=q.dtype)
dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype)
if bias_type != "no_bias":
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
if qkv_format == "thd":
dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = None
# execute kernel
......@@ -970,6 +988,9 @@ def fused_attn_fwd(
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
else:
......@@ -1065,6 +1086,9 @@ def fused_attn_bwd(
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype)
......@@ -1074,7 +1098,10 @@ def fused_attn_bwd(
dk = paddle.empty(shape=k.shape, dtype=k.dtype)
dv = paddle.empty(shape=v.shape, dtype=v.dtype)
if bias_type != "no_bias":
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
if qkv_format == "thd":
dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = None
# execute kernel
......
......@@ -3068,7 +3068,7 @@ def check_set_window_size(
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] < 0 or orig_window_size[0] < 0:
elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
assert False, (
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
......
......@@ -127,6 +127,9 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
O.fill_(0);
}
// BF16 or FP16
te_QKV =
makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr);
......@@ -288,6 +291,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
amax_dQKV.value().data_ptr(),
scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
dQKV.fill_(0);
}
// BF16 or FP16
te_QKV =
makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr);
......@@ -328,6 +334,9 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
options);
te_dBias = makeTransformerEngineTensor(dBias);
}
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
dBias.fill_(0);
}
}
// create cu_seqlens tensorwrappers
......@@ -427,6 +436,9 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
O.fill_(0);
}
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
te_KV =
......@@ -614,6 +626,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
amax_dQKV.value().data_ptr(),
scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
dQ.fill_(0);
dKV.fill_(0);
}
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
te_KV =
......@@ -684,6 +700,9 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
options);
te_dBias = makeTransformerEngineTensor(dBias);
}
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
dBias.fill_(0);
}
}
// create workspace
......@@ -774,6 +793,9 @@ std::vector<at::Tensor> fused_attn_fwd(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
O.fill_(0);
}
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
......@@ -1037,6 +1059,11 @@ std::vector<at::Tensor> fused_attn_bwd(
makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, amax_dQKV.value().data_ptr(),
scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
dQ.fill_(0);
dK.fill_(0);
dV.fill_(0);
}
// BF16 or FP16
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
......@@ -1109,6 +1136,9 @@ std::vector<at::Tensor> fused_attn_bwd(
options);
te_dBias = makeTransformerEngineTensor(dBias);
}
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
dBias.fill_(0);
}
}
// create workspace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment