Commit 9364c387 authored by Anthony Chang's avatar Anthony Chang
Browse files

workaround clang crash for gfx908

(gfx908 only) workaround for compiler crash in fused kernels on mainline #9110; #10738 seems ok
error message was "fatal error: error in backend: Error while trying to spill VGPR0 from class
VGPR_32: Cannot scavenge register without an emergency spill slot!"
this fall back to less ideal way of handle NPadding in fused attention kernel
parent 90b771c0
......@@ -144,6 +144,17 @@
// workaround: compiler gnerating inefficient ds_write instructions
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
// (gfx908 only) workaround: compiler crash in fused kernels on mainline #9110; #10738 seems ok
// error message was "fatal error: error in backend: Error while trying to spill VGPR0 from class
// VGPR_32: Cannot scavenge register without an emergency spill slot!"
// this fall back to less ideal way of handle NPadding in fused attention kernel
#ifdef __gfx908__
#define CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER 1
#else
// for __gfx90a__, ...
#define CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER 0
#endif // __gfx908__
// workaround: verifaction failure, due to compiler regression, for conv bwd-data fp16 using some
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
......
......@@ -16,7 +16,8 @@ template <index_t BlockSize,
typename AccDataType,
typename ThreadMap_M_K, // thread_id to m_k
typename ThreadClusterDesc_M_K,
typename ThreadSliceDesc_M_K>
typename ThreadSliceDesc_M_K,
bool IgnoreNaN = false>
struct BlockwiseSoftmax
{
static constexpr auto I0 = Number<0>{};
......@@ -27,11 +28,33 @@ struct BlockwiseSoftmax
using ThreadSliceDesc_M = decltype(
make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Max,
false>;
using ThreadwiseMaxReduce = typename conditional<
IgnoreNaN,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Max,
false,
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Max,
false>>::type;
using ThreadwiseSumReduce = typename conditional<
IgnoreNaN,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Add,
false,
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Add,
false>>::type;
using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
......@@ -49,12 +72,6 @@ struct BlockwiseSoftmax
reduce::Add,
false>;
using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Add,
false>;
using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>;
template <typename CThreadBuffer, typename WorkspaceBuffer>
......@@ -74,7 +91,9 @@ struct BlockwiseSoftmax
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = math::exp(in_thread_buf[offset] - max_value_buf(iM));
in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - max_value_buf(iM));
});
});
......
......@@ -717,7 +717,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
decltype(thread_slice_desc_m_n)
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
,
true
#endif
>{};
const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
......@@ -758,10 +763,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
num_k_block_main_loop);
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
#else
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
});
#endif
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment