Commit dccaf0b2 authored by letaoqin's avatar letaoqin
Browse files

change check name to d0s_n_length_stride_

parent 514cee8a
......@@ -721,9 +721,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// D0 pointer
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
// for check
d0s_nl_ns_lengths_strides_[i].push_back(
d0s_n_length_stride_[i].push_back(
acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
d0s_nl_ns_lengths_strides_[i].push_back(
d0s_n_length_stride_[i].push_back(
acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
});
is_dropout_ = p_dropout > 0.0; //
......@@ -830,7 +830,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t n_raw_padded_;
// raw data
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_n_length_stride_;
};
// Invoker
......@@ -1039,12 +1039,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
for(int i = 0; i < NumD0Tensor; i++)
{
if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
arg.d0s_nl_ns_lengths_strides_[i][0] % Acc0BiasTransferSrcScalarPerVector != 0)
if(arg.d0s_n_length_stride_[i][1] == 1 &&
arg.d0s_n_length_stride_[i][0] % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1)
if(arg.d0s_n_length_stride_[i][1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1)
{
return false;
}
......
......@@ -658,7 +658,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CGridDesc_M_N c_grid_desc_m_n_;
// raw data
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_n_length_stride_;
};
// Argument
......@@ -708,16 +708,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto& problem_desc = problem_desc_vec[i];
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides;
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_n_length_stride;
typename GridwiseGemm::D0sGridPointer p_d0s_grid;
static_for<0, NumD0Tensor, 1>{}([&](auto j) {
using D0DataType = remove_cvref_t<tuple_element_t<j.value, Acc0BiasDataType>>;
// D0 pointer
p_d0s_grid(j) = static_cast<const D0DataType*>(p_acc0_biases_vec[i][j]);
// for check
d0s_nl_ns_lengths_strides[j].push_back(
d0s_n_length_stride[j].push_back(
problem_desc.acc0_biases_gs_ms_ns_lengths[j][NumDimG + NumDimM]);
d0s_nl_ns_lengths_strides[j].push_back(
d0s_n_length_stride[j].push_back(
problem_desc.acc0_biases_gs_ms_ns_strides[j][NumDimG + NumDimM]);
});
......@@ -859,7 +859,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n,
d0s_nl_ns_lengths_strides});
d0s_n_length_stride});
}
is_dropout_ = p_dropout > 0.0; //
......@@ -1081,14 +1081,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
for(int In = 0; In < NumD0Tensor; In++)
{
if(device_arg.d0s_nl_ns_lengths_strides_[In][1] == 1 &&
device_arg.d0s_nl_ns_lengths_strides_[In][0] %
Acc0BiasTransferSrcScalarPerVector !=
0)
if(device_arg.d0s_n_length_stride_[In][1] == 1 &&
device_arg.d0s_n_length_stride_[In][0] % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
if(device_arg.d0s_nl_ns_lengths_strides_[In][1] != 1 &&
if(device_arg.d0s_n_length_stride_[In][1] != 1 &&
Acc0BiasTransferSrcScalarPerVector != 1)
{
return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment