Commit 9364c387 authored by Anthony Chang's avatar Anthony Chang
Browse files

workaround clang crash for gfx908

(gfx908 only) workaround for compiler crash in fused kernels on mainline #9110; #10738 seems ok
error message was "fatal error: error in backend: Error while trying to spill VGPR0 from class
VGPR_32: Cannot scavenge register without an emergency spill slot!"
this fall back to less ideal way of handle NPadding in fused attention kernel
parent 90b771c0
...@@ -144,6 +144,17 @@ ...@@ -144,6 +144,17 @@
// workaround: compiler gnerating inefficient ds_write instructions // workaround: compiler gnerating inefficient ds_write instructions
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
// (gfx908 only) workaround: compiler crash in fused kernels on mainline #9110; #10738 seems ok
// error message was "fatal error: error in backend: Error while trying to spill VGPR0 from class
// VGPR_32: Cannot scavenge register without an emergency spill slot!"
// this fall back to less ideal way of handle NPadding in fused attention kernel
#ifdef __gfx908__
#define CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER 1
#else
// for __gfx90a__, ...
#define CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER 0
#endif // __gfx908__
// workaround: verifaction failure, due to compiler regression, for conv bwd-data fp16 using some // workaround: verifaction failure, due to compiler regression, for conv bwd-data fp16 using some
// tuning parameter // tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0 #define CK_WORKAROUND_SWDEV_325164 0
......
...@@ -16,7 +16,8 @@ template <index_t BlockSize, ...@@ -16,7 +16,8 @@ template <index_t BlockSize,
typename AccDataType, typename AccDataType,
typename ThreadMap_M_K, // thread_id to m_k typename ThreadMap_M_K, // thread_id to m_k
typename ThreadClusterDesc_M_K, typename ThreadClusterDesc_M_K,
typename ThreadSliceDesc_M_K> typename ThreadSliceDesc_M_K,
bool IgnoreNaN = false>
struct BlockwiseSoftmax struct BlockwiseSoftmax
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -27,11 +28,33 @@ struct BlockwiseSoftmax ...@@ -27,11 +28,33 @@ struct BlockwiseSoftmax
using ThreadSliceDesc_M = decltype( using ThreadSliceDesc_M = decltype(
make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))); make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType, using ThreadwiseMaxReduce = typename conditional<
IgnoreNaN,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K, ThreadSliceDesc_M_K,
ThreadSliceDesc_M, ThreadSliceDesc_M,
reduce::Max, reduce::Max,
false>; false,
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Max,
false>>::type;
using ThreadwiseSumReduce = typename conditional<
IgnoreNaN,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Add,
false,
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Add,
false>>::type;
using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths()); using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
...@@ -49,12 +72,6 @@ struct BlockwiseSoftmax ...@@ -49,12 +72,6 @@ struct BlockwiseSoftmax
reduce::Add, reduce::Add,
false>; false>;
using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Add,
false>;
using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>; using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>;
template <typename CThreadBuffer, typename WorkspaceBuffer> template <typename CThreadBuffer, typename WorkspaceBuffer>
...@@ -74,7 +91,9 @@ struct BlockwiseSoftmax ...@@ -74,7 +91,9 @@ struct BlockwiseSoftmax
static_for<0, MRepeat, 1>{}([&](auto iM) { static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) { static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = math::exp(in_thread_buf[offset] - max_value_buf(iM)); in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - max_value_buf(iM));
}); });
}); });
......
...@@ -717,7 +717,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -717,7 +717,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc, FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor), decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n), decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{}; decltype(thread_slice_desc_m_n)
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
,
true
#endif
>{};
const index_t num_gemm1_k_block_outer_loop = const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
...@@ -758,10 +763,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -758,10 +763,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
num_k_block_main_loop); num_k_block_main_loop);
// Acc0 elementwise Op // Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
#else
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) { static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run( ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(i), acc_element_op, acc_thread_buf[i]); acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
}); });
#endif
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment