Commit 2aa9cbee authored by ltqin
Browse files

Merge branch 'attn-train-develop-qloop' into attn-train-develop-qloop-light

parents be38f68d 41c659bb
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -81,7 +81,8 @@ __global__ void
const unsigned long long seed,
const unsigned long long offset)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -449,7 +450,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
......@@ -898,7 +899,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -83,7 +83,8 @@ __global__ void
const index_t raw_m_padded,
const index_t raw_n_padded)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -457,7 +458,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
......@@ -915,7 +916,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -52,7 +52,8 @@ __global__ void
const unsigned long long seed,
const unsigned long long offset)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
......@@ -240,13 +241,6 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -576,7 +570,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
......@@ -628,14 +622,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -1020,7 +1006,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......@@ -1051,14 +1039,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
const auto c_extent_lowest = Gemm1NzRaw;
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto c_extent_lowest = Gemm1NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
......@@ -1071,15 +1057,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
? device_arg.b_nz_kz_strides_[1]
: device_arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
? device_arg.b1_nz_kz_strides_[1]
: device_arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
// contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || c_stride_lowest == 1))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -52,7 +52,8 @@ __global__ void
const unsigned long long seed,
const unsigned long long offset)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
......@@ -569,7 +570,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
......@@ -1012,7 +1013,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -52,7 +53,8 @@ __global__ void
const unsigned long long seed,
const unsigned long long offset)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
......@@ -104,7 +106,7 @@ __global__ void
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -140,7 +142,7 @@ __global__ void
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -573,7 +575,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
......@@ -960,7 +962,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1<
GridwiseGemm,
......@@ -971,6 +973,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_,
Deterministic>;
return launch_and_time_kernel(
......@@ -995,11 +998,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
}
else if(!some_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
}
else
{
......@@ -1025,7 +1034,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
bool IsDropout,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -52,7 +53,8 @@ __global__ void
const unsigned long long seed,
const unsigned long long offset)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
......@@ -104,7 +106,7 @@ __global__ void
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -140,7 +142,7 @@ __global__ void
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
......@@ -573,7 +575,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
......@@ -967,7 +969,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2<
GridwiseGemm,
......@@ -978,6 +980,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_,
Deterministic>;
return launch_and_time_kernel(
......@@ -1002,11 +1005,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
}
else if(!some_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(arg.p_dropout_ > 0.0)
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
else
ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
}
else
{
......@@ -1032,7 +1041,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -52,7 +52,8 @@ __global__ void
const unsigned long long seed,
const unsigned long long offset)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......@@ -465,7 +466,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
......@@ -938,7 +939,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -52,7 +52,8 @@ __global__ void
const unsigned long long seed,
const unsigned long long offset)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......@@ -471,7 +472,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
......@@ -960,7 +961,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -71,14 +71,6 @@ template <typename InputDataType,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
index_t B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -88,7 +80,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -1943,8 +1935,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
{
s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
......@@ -2086,6 +2078,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
p_slice_idx[I3],
p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
......@@ -2096,8 +2089,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
gemm2_a_block_buf);
}
// block_sync_lds(); // sync before write
vgrad_gemm_tile_ygrad_blockwise_copy.Run(Gemm2::b_block_desc_o0_o1_o2_m0_m1_m2_m3,
ygrad_block_buf,
Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3,
......@@ -2135,6 +2126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -88,7 +88,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -1838,16 +1838,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
{
s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
}
else
{
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
}
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
s_element_op(s_slash_p_thread_buf(i),
masked_flag ? -ck::NumericLimits<float>::Infinity()
: s_slash_p_thread_buf[i]);
});
}
else
......@@ -1924,6 +1920,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
p_slice_idx[I3],
p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
......@@ -1939,7 +1936,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf);
......@@ -1987,17 +1983,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P have the same thread buffer layout
if(s_slash_p_thread_buf[i] >= 0)
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] *
(pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
}
else
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
}
bool undropped_flag = s_slash_p_thread_buf[i] >= 0;
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] *
(undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
// gemm dQ
......@@ -2082,6 +2072,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
......@@ -2098,7 +2089,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -81,7 +81,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -1222,6 +1222,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
typename C0MatrixMask,
typename YGradGridDesc_O0_M_O1>
......@@ -1947,56 +1948,57 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
constexpr auto position_offset = M3 * M4;
// save z to global
if(p_z_grid)
{
if constexpr(IsDropout){
if(p_z_grid)
{
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
else
{
ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
else
{
ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
}
block_sync_lds(); // wait for gemm1 LDS read
// dS = P * (dP - Y_dot_dY)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -89,7 +89,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -1154,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
typename C0MatrixMask,
typename YGradGridDesc_M0_O_M1>
......@@ -1863,55 +1864,56 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
constexpr auto position_offset = M3 * M4;
// save z to global
if(p_z_grid)
{
if constexpr(IsDropout){
if(p_z_grid)
{
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
else
{
ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
else
{
ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
}
block_sync_lds(); // wait for gemm1 LDS read
// gemm dV
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -89,7 +89,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -81,7 +81,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -89,7 +89,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -87,7 +87,7 @@ template <typename FloatAB,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -87,7 +87,7 @@ template <typename FloatAB,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
......@@ -219,7 +219,7 @@ typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, unsigned short>,
bool>::type
check_err(const Range& out, const RefRange& ref, unsigned short atol = 1)
check_integer_err(const Range& out, const RefRange& ref, unsigned short atol)
{
const std::string& msg = "Error: Incorrect U16 results!";
if(out.size() != ref.size())
......@@ -262,7 +262,7 @@ typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, int32_t>,
bool>::type
check_err(const Range& out, const RefRange& ref, int32_t atol = 1)
check_integer_err(const Range& out, const RefRange& ref, int32_t atol)
{
const std::string& msg = "Error: Incorrect U16 results!";
if(out.size() != ref.size())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment