Commit e296ee56 authored by ltqin
Browse files

Fix out-of-range access when the z (and lse) pointer vectors are empty

parent c9915508
......@@ -691,7 +691,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size() &&
(group_count_ == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0)))
(group_count_ == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0) &&
(group_count_ == p_z_vec.size() || p_z_vec.size() == 0) &&
(group_count_ == p_lse_vec.size() || p_lse_vec.size() == 0)))
{
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
}
......@@ -704,13 +706,17 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_d0_grid = p_acc0_biases_vec.size() > 0
const auto p_d0_grid = (p_acc0_biases_vec.size() == group_count_)
? static_cast<const D0DataType*>(p_acc0_biases_vec[i])
: nullptr;
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
const auto p_z_grid = (p_z_vec.size() == group_count_)
? static_cast<ZDataType*>(p_z_vec[i])
: nullptr;
const auto p_lse_grid = (p_lse_vec.size() == group_count_)
? static_cast<LSEDataType*>(p_lse_vec[i])
: nullptr;
if(p_lse_grid == nullptr)
{
......@@ -851,7 +857,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_n_length_stride});
}
is_dropout_ = p_dropout > 0.0; //
use_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_;
......@@ -878,7 +884,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
unsigned long long seed_;
unsigned long long offset_;
GemmAccDataType p_dropout_rescale_;
bool is_dropout_;
bool use_dropout_;
bool is_lse_storing_ = true;
};
......@@ -914,7 +920,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
[&](auto has_main_k_block_loop_, auto use_dropout_, auto is_lse_storing_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
D0DataType,
......@@ -926,7 +932,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_,
use_dropout_,
is_lse_storing_,
Deterministic>;
......@@ -953,7 +959,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
if(arg.is_dropout_)
if(arg.use_dropout_)
{
if(arg.is_lse_storing_)
{
......@@ -986,7 +992,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
}
else if(!some_has_main_k_block_loop)
{
if(arg.is_dropout_)
if(arg.use_dropout_)
{
if(arg.is_lse_storing_)
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment