Commit 28459058 authored by letaoqin's avatar letaoqin
Browse files

add check code

parent d10f25a0
...@@ -1129,6 +1129,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1129,6 +1129,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return false; return false;
} }
if constexpr(!is_same<D0DataType, void>::value)
{
if(device_arg.d0_n_length_stride_[1] == 1 &&
device_arg.d0_n_length_stride_[0] % D0BlockTransferSrcScalarPerVector != 0)
{
return false;
}
if(device_arg.d0_n_length_stride_[1] != 1 && D0BlockTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part // Note: we need raw lengths since threadwise copy can not handle vector load when part
// of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
......
...@@ -1135,6 +1135,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1135,6 +1135,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return false; return false;
} }
if constexpr(!is_same<D0DataType, void>::value)
{
if(device_arg.d0_n_length_stride_[1] == 1 &&
device_arg.d0_n_length_stride_[0] % D0BlockTransferSrcScalarPerVector != 0)
{
return false;
}
if(device_arg.d0_n_length_stride_[1] != 1 && D0BlockTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part // Note: we need raw lengths since threadwise copy can not handle vector load when part
// of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment