Commit e296ee56 authored by ltqin
Browse files

Fix empty z pointers issue: only index p_z_vec / p_lse_vec when they hold one entry per group; otherwise pass nullptr

parent c9915508
...@@ -691,7 +691,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -691,7 +691,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() && if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size() && group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size() &&
(group_count_ == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0))) (group_count_ == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0) &&
(group_count_ == p_z_vec.size() || p_z_vec.size() == 0) &&
(group_count_ == p_lse_vec.size() || p_lse_vec.size() == 0)))
{ {
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size"); throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
} }
...@@ -704,13 +706,17 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -704,13 +706,17 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]); const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]); const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_d0_grid = p_acc0_biases_vec.size() > 0 const auto p_d0_grid = (p_acc0_biases_vec.size() == group_count_)
? static_cast<const D0DataType*>(p_acc0_biases_vec[i]) ? static_cast<const D0DataType*>(p_acc0_biases_vec[i])
: nullptr; : nullptr;
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]); const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]); const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]); const auto p_z_grid = (p_z_vec.size() == group_count_)
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]); ? static_cast<ZDataType*>(p_z_vec[i])
: nullptr;
const auto p_lse_grid = (p_lse_vec.size() == group_count_)
? static_cast<LSEDataType*>(p_lse_vec[i])
: nullptr;
if(p_lse_grid == nullptr) if(p_lse_grid == nullptr)
{ {
...@@ -851,7 +857,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -851,7 +857,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_n_length_stride}); d0_n_length_stride});
} }
is_dropout_ = p_dropout > 0.0; // use_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout; p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0)); p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_; p_dropout_ = 1.f / p_dropout_;
...@@ -878,7 +884,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -878,7 +884,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
unsigned long long seed_; unsigned long long seed_;
unsigned long long offset_; unsigned long long offset_;
GemmAccDataType p_dropout_rescale_; GemmAccDataType p_dropout_rescale_;
bool is_dropout_; bool use_dropout_;
bool is_lse_storing_ = true; bool is_lse_storing_ = true;
}; };
...@@ -914,7 +920,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -914,7 +920,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
float ave_time = 0; float ave_time = 0;
auto launch_kernel = auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) { [&](auto has_main_k_block_loop_, auto use_dropout_, auto is_lse_storing_) {
const auto kernel = const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm, kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
D0DataType, D0DataType,
...@@ -926,7 +932,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -926,7 +932,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
has_main_k_block_loop_, has_main_k_block_loop_,
is_dropout_, use_dropout_,
is_lse_storing_, is_lse_storing_,
Deterministic>; Deterministic>;
...@@ -953,7 +959,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -953,7 +959,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
if(arg.is_dropout_) if(arg.use_dropout_)
{ {
if(arg.is_lse_storing_) if(arg.is_lse_storing_)
{ {
...@@ -986,7 +992,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -986,7 +992,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
if(arg.is_dropout_) if(arg.use_dropout_)
{ {
if(arg.is_lse_storing_) if(arg.is_lse_storing_)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment