Unverified commit e10997bf authored by cyanguwa, committed by GitHub

[PyTorch/C] Fix compiler warnings and backend selection logic for fused attention (#559)



* fix backend selection for sm80
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix compiling warnings in sdpa flash
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add nvte error messages
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add NVTE_CHECK_CUDNN_FE for error messaging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable pylint bare-except
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent bd0873af
@@ -131,11 +131,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
             || (cudnn_runtime_version >= 8907))
         && ((head_dim <= 128) && (head_dim % 8 == 0))
         && ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
-            || ((cudnn_runtime_version >= 8906 && sm_arch_ == 90)
+            || ((cudnn_runtime_version >= 8906)
                 && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
                     || (bias_type == NVTE_Bias_Type::NVTE_ALIBI
-                        && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK)
-                    || bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)))
+                        && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK
+                        && sm_arch_ == 90)
+                    || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
+                        && sm_arch_ == 90))))
         && ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
             || ((cudnn_runtime_version >= 8906)
                 && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK
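The hunk above narrows where sm_arch_ == 90 is required: with cuDNN 8.9.6 and later, the no-bias path is no longer limited to sm90 (the sm80 fix from the commit message), while ALIBI bias (with a mask) and post-scale bias still require sm90. A minimal sketch of just that bias/arch clause, factored into a hypothetical helper for readability (the helper name and signature are illustrative; in the source this is one clause of a much larger condition in nvte_get_fused_attn_backend):

// Sketch only: reproduces the bias/arch clause from the hunk above.
static bool bias_clause_ok(NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                           int sm_arch_, int cudnn_runtime_version) {
    if (cudnn_runtime_version < 8906) {
        return bias_type == NVTE_Bias_Type::NVTE_NO_BIAS;
    }
    // cuDNN 8.9.6+: no-bias now also passes on sm80; ALIBI (with a mask)
    // and post-scale bias keep the sm90 requirement.
    return bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
        || (bias_type == NVTE_Bias_Type::NVTE_ALIBI
            && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK
            && sm_arch_ == 90)
        || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
            && sm_arch_ == 90);
}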
@@ -93,7 +93,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                      std::shared_ptr<fe::graph::Tensor_attributes> >;  // dropout_offset
   using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
-  static thread_local CacheType sdpa_flash_f16_fprop_cache;
+  static thread_local CacheType sdpa_f16_fprop_cache;
   // Get plan from cache if cache is available, otherwise create one
   auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
@@ -144,23 +144,21 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                         .set_is_pass_by_value(true)
                         .set_data_type(fe::DataType_t::FLOAT));
-    fe::graph::Scaled_dot_product_flash_attention_attributes
-        scaled_dot_product_flash_attention_options;
-    scaled_dot_product_flash_attention_options =
-        fe::graph::Scaled_dot_product_flash_attention_attributes()
-            .set_name("flash_attention")
-            .set_is_inference(!is_training)
-            .set_causal_mask(is_causal)
-            .set_attn_scale(attn_scale);
+    fe::graph::SDPA_attributes sdpa_options;
+    sdpa_options = fe::graph::SDPA_attributes()
+        .set_name("flash_attention")
+        .set_is_inference(!is_training)
+        .set_causal_mask(is_causal)
+        .set_attn_scale(attn_scale);
-    scaled_dot_product_flash_attention_options.set_alibi_mask(is_alibi);
+    sdpa_options.set_alibi_mask(is_alibi);
     if (is_bias) {
       bias = mha_graph->tensor(fe::graph::Tensor_attributes()
                 .set_name("bias")
                 .set_dim({1, h, s_q, s_kv})
                 .set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
-      scaled_dot_product_flash_attention_options.set_bias(bias);
+      sdpa_options.set_bias(bias);
     }
     if (is_padding) {
@@ -174,7 +172,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                   .set_dim({b, 1, 1, 1})
                   .set_stride({1, 1, 1, 1})
                   .set_data_type(fe::DataType_t::INT32));
-      scaled_dot_product_flash_attention_options.set_padding_mask(is_padding)
+      sdpa_options.set_padding_mask(is_padding)
           .set_seq_len_q(seq_q)
           .set_seq_len_kv(seq_kv);
     }
@@ -190,12 +188,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                   .set_dim({1, 1, 1, 1})
                   .set_stride({1, 1, 1, 1})
                   .set_data_type(fe::DataType_t::INT64));
-      scaled_dot_product_flash_attention_options.set_dropout(
+      sdpa_options.set_dropout(
           dropout_probability, dropout_seed, dropout_offset);
     }
-    auto [O, Stats] = mha_graph->scaled_dot_product_flash_attention(
-        Q, K, V, scaled_dot_product_flash_attention_options);
+    auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options);
     std::vector<int64_t> o_stride(4);
     generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
@@ -224,11 +221,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         std::make_tuple(nullptr), key_tensors_tuple,
         Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
-    mha_graph->validate();
-    mha_graph->build_operation_graph(handle);
-    mha_graph->create_execution_plans({fe::HeurMode_t::A});
-    mha_graph->check_support(handle);
-    mha_graph->build_plans(handle);
+    NVTE_CHECK_CUDNN_FE(mha_graph->validate());
+    NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
+    NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A}));
+    NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
+    NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
     auto return_tuple = std::tuple_cat(
         std::make_tuple(mha_graph), key_tensors_tuple,
@@ -240,7 +237,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   auto [mha_graph, Q, K, V, attn_scale, O, Stats,
         bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
-      sdpa_flash_f16_fprop_cache, descriptor);
+      sdpa_f16_fprop_cache, descriptor);
   auto plan_workspace_size = mha_graph->get_workspace_size();
@@ -286,7 +283,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
       variant_pack[dropout_offset] = devPtrDropoutOffset;
     }
-    mha_graph->execute(handle, variant_pack, workspace);
+    NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
   } catch (cudnn_frontend::cudnnException &e) {
     NVTE_ERROR(e.what());
   }
@@ -342,7 +339,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                      std::shared_ptr<fe::graph::Tensor_attributes> >;  // dropout_offset
   using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
-  static thread_local CacheType sdpa_flash_f16_bprop_cache;
+  static thread_local CacheType sdpa_f16_bprop_cache;
   // Get plan from cache if cache is available, otherwise create one
   auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
@@ -409,15 +406,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                         .set_is_pass_by_value(true)
                         .set_data_type(fe::DataType_t::FLOAT));
-    fe::graph::Scaled_dot_product_flash_attention_backward_attributes
-        scaled_dot_product_flash_attention_backward_options;
-    scaled_dot_product_flash_attention_backward_options =
-        fe::graph::Scaled_dot_product_flash_attention_backward_attributes()
-            .set_name("flash_attention_backward")
-            .set_causal_mask(is_causal)
-            .set_attn_scale(attn_scale);
+    fe::graph::SDPA_backward_attributes sdpa_backward_options;
+    sdpa_backward_options = fe::graph::SDPA_backward_attributes()
+        .set_name("flash_attention_backward")
+        .set_causal_mask(is_causal)
+        .set_attn_scale(attn_scale);
-    scaled_dot_product_flash_attention_backward_options.set_alibi_mask(is_alibi);
+    sdpa_backward_options.set_alibi_mask(is_alibi);
     if (is_bias) {
       bias = mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -428,8 +423,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                    .set_name("dBias")
                    .set_dim({1, h, s_q, s_kv})
                    .set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
-      scaled_dot_product_flash_attention_backward_options.set_bias(bias);
-      scaled_dot_product_flash_attention_backward_options.set_dbias(dBias);
+      sdpa_backward_options.set_bias(bias);
+      sdpa_backward_options.set_dbias(dBias);
     }
     if (is_padding) {
@@ -443,7 +438,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                   .set_dim({b, 1, 1, 1})
                   .set_stride({1, 1, 1, 1})
                   .set_data_type(fe::DataType_t::INT32));
-      scaled_dot_product_flash_attention_backward_options.set_padding_mask(is_padding)
+      sdpa_backward_options.set_padding_mask(is_padding)
           .set_seq_len_q(seq_q)
           .set_seq_len_kv(seq_kv);
     }
@@ -459,12 +454,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                   .set_dim({1, 1, 1, 1})
                   .set_stride({1, 1, 1, 1})
                   .set_data_type(fe::DataType_t::INT64));
-      scaled_dot_product_flash_attention_backward_options.set_dropout(
+      sdpa_backward_options.set_dropout(
           dropout_probability, dropout_seed, dropout_offset);
     }
-    auto [dQ, dK, dV] = mha_graph->scaled_dot_product_flash_attention_backward(
-        q, k, v, o, dO, stats, scaled_dot_product_flash_attention_backward_options);
+    auto [dQ, dK, dV] = mha_graph->sdpa_backward(
+        q, k, v, o, dO, stats, sdpa_backward_options);
     dQ->set_output(true)
        .set_dim({b, h, s_q, d})
@@ -497,11 +492,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         std::make_tuple(nullptr), key_tensors_tuple,
         bias_tuple, padding_tuple, dropout_tuple);
-    mha_graph->validate();
-    mha_graph->build_operation_graph(handle);
-    mha_graph->create_execution_plans({fe::HeurMode_t::A});
-    mha_graph->check_support(handle);
-    mha_graph->build_plans(handle);
+    NVTE_CHECK_CUDNN_FE(mha_graph->validate());
+    NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
+    NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A}));
+    NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
+    NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
     auto return_tuple = std::tuple_cat(
         std::make_tuple(mha_graph), key_tensors_tuple,
@@ -513,7 +508,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV,
         bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
-      sdpa_flash_f16_bprop_cache, descriptor);
+      sdpa_f16_bprop_cache, descriptor);
   auto plan_workspace_size = mha_graph->get_workspace_size();
@@ -562,7 +557,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       variant_pack[dropout_offset] = devPtrDropoutOffset;
     }
-    mha_graph->execute(handle, variant_pack, workspace);
+    NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
   } catch (cudnn_frontend::cudnnException &e) {
     NVTE_ERROR(e.what());
   }
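Taken together, the forward and backward hunks above switch to the cuDNN frontend 1.x names (SDPA_attributes / sdpa() and SDPA_backward_attributes / sdpa_backward()) and route every frontend return code through NVTE_CHECK_CUDNN_FE. A condensed sketch of the forward flow, with tensor creation and the FADescriptor_v1 cache lookup elided (illustrative only, not the exact implementation):

namespace fe = cudnn_frontend;

auto mha_graph = std::make_shared<fe::graph::Graph>();
// Q, K, V (and optional bias / seq-len / dropout tensors) are created with
// mha_graph->tensor(fe::graph::Tensor_attributes()...) as in the hunks above.

auto sdpa_options = fe::graph::SDPA_attributes()
    .set_name("flash_attention")
    .set_is_inference(!is_training)
    .set_causal_mask(is_causal)
    .set_attn_scale(attn_scale);
auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options);

// Each frontend step now surfaces failures through NVTE_ERROR.
NVTE_CHECK_CUDNN_FE(mha_graph->validate());
NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A}));
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));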
@@ -64,6 +64,19 @@
     } \
   } while (false)
+#define NVTE_CHECK_CUDNN_FE(expr) \
+  do { \
+    const auto error = (expr); \
+    if (error.is_bad()) { \
+      NVTE_ERROR("cuDNN Error: ", \
+                 error.err_msg, \
+                 ". " \
+                 "For more information, enable cuDNN error logging " \
+                 "by setting CUDNN_LOGERR_DBG=1 and " \
+                 "CUDNN_LOGDEST_DBG=stderr in the environment."); \
+    } \
+  } while (false)
+
 #define NVTE_CHECK_NVRTC(expr) \
   do { \
     const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \
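NVTE_CHECK_CUDNN_FE assumes expr evaluates to a cudnn_frontend error object exposing is_bad() and err_msg, which is what the graph calls wrapped above return. A hand-expanded equivalent of one wrapped call, shown only to illustrate what the macro does:

// Roughly what NVTE_CHECK_CUDNN_FE(mha_graph->validate()) expands to:
do {
    const auto error = (mha_graph->validate());
    if (error.is_bad()) {
        NVTE_ERROR("cuDNN Error: ", error.err_msg,
                   ". For more information, enable cuDNN error logging "
                   "by setting CUDNN_LOGERR_DBG=1 and "
                   "CUDNN_LOGDEST_DBG=stderr in the environment.");
    }
} while (false);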