Commit 9096e2af authored by ltqin's avatar ltqin
Browse files

Merge branch 'attn-bwd-develop' into attn-bwd-bhalf-test

parents ac407086 8453af0c
...@@ -25,6 +25,7 @@ Kernel outputs: ...@@ -25,6 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define USING_HD32 0
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -86,6 +87,80 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali ...@@ -86,6 +87,80 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
// Headdim/K/O must be a multiple of 8, and prototype1 supports it only up to 64.
// If Headdim/K/O <= 32, use the 1st template.
// If 32 < Headdim/K/O <= 64, use the 2nd template.
#if USING_HD32
// 1st template
// DeviceGemmInstance: batched multi-head-attention backward device op for small
// head dimensions (Headdim/K/O <= 32). The template arguments are positional;
// trailing comments name each parameter where the name is known from the
// template declaration, and are marked "presumably" where it is not visible here.
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp, // elementwise op applied to Q — presumably; confirm against template
QKVElementOp, // elementwise op applied to K — presumably; confirm against template
Scale,        // acc0 (S = Q*K^T) elementwise op — presumably the alpha scaling
QKVElementOp, // elementwise op applied to V — presumably; confirm against template
YElementOp,   // elementwise op applied to output Y
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,   // NumGemmKPrefetchStage — presumably; TODO confirm
256, // BlockSize (threads per block) — presumably; TODO confirm
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer thread cluster lengths (AK0, M, AK1)
S<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>,  // ABlockTransferSrcAccessOrder
2,    // A src vector dim — presumably; TODO confirm
8,    // A src scalar per vector — presumably; TODO confirm
8,    // A dst scalar per vector (AK1) — presumably; TODO confirm
true, // A LDS extra-padding flag — presumably; TODO confirm
S<4, 64, 1>, // BBlockTransfer thread cluster lengths
S<1, 0, 2>,  // B thread cluster arrange order — presumably mirrors ABlockTransfer
S<1, 0, 2>,  // B src access order — presumably mirrors ABlockTransfer
2,    // B src vector dim — presumably; TODO confirm
8,    // B src scalar per vector — presumably; TODO confirm
8,    // B dst scalar per vector — presumably; TODO confirm
true, // B LDS extra-padding flag — presumably; TODO confirm
S<8, 32, 1>, // B1BlockTransfer thread cluster lengths
S<0, 2, 1>,  // B1 thread cluster arrange order — presumably; TODO confirm
S<0, 2, 1>,  // B1 src access order — presumably; TODO confirm
1,     // B1 src vector dim — presumably; TODO confirm
4,     // B1 src scalar per vector — presumably; TODO confirm
2,     // B1 dst scalar per vector — presumably; TODO confirm
false, // B1 LDS extra-padding flag — presumably; TODO confirm
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#else
// 2nd template
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG, NumDimG,
...@@ -125,6 +200,7 @@ using DeviceGemmInstance = ...@@ -125,6 +200,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -151,6 +227,7 @@ using DeviceGemmInstance = ...@@ -151,6 +227,7 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#endif
// Ref Gemm0: S = alpha * Q * K^T // Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out // fp16 in, fp32 out
......
...@@ -5,7 +5,7 @@ int run(int argc, char* argv[]) ...@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = true;
// GEMM shape for A/B0/B1/C // GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
...@@ -175,7 +175,7 @@ int run(int argc, char* argv[]) ...@@ -175,7 +175,7 @@ int run(int argc, char* argv[])
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; {}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; {}, // std::array<void*, 1> p_acc1_biases;
...@@ -228,6 +228,44 @@ int run(int argc, char* argv[]) ...@@ -228,6 +228,44 @@ int run(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
// run for storing z tensor
argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number
// of elements on a thread
c_device_buf.SetZero();
lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
z_device_buf.FromDevice(z_gs_ms_ns.mData.data()); z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
lse_device_buf.FromDevice(lse_gs_ms_device_result.mData.data()); lse_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
......
...@@ -5,7 +5,7 @@ int run(int argc, char* argv[]) ...@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = true;
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
...@@ -56,7 +56,8 @@ int run(int argc, char* argv[]) ...@@ -56,7 +56,8 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0; std::vector<const void*> p_b0;
std::vector<const void*> p_b1; std::vector<const void*> p_b1;
std::vector<void*> p_c; std::vector<void*> p_c;
std::vector<void*> p_z; std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse; std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o; std::vector<std::vector<int>> g0_g1_m_n_k_o;
...@@ -221,6 +222,7 @@ int run(int argc, char* argv[]) ...@@ -221,6 +222,7 @@ int run(int argc, char* argv[])
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer()); p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer()); p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
} }
...@@ -233,12 +235,13 @@ int run(int argc, char* argv[]) ...@@ -233,12 +235,13 @@ int run(int argc, char* argv[])
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = auto argument =
gemm.MakeArgument(p_a, gemm.MakeArgument(p_a,
p_b0, p_b0,
p_b1, p_b1,
p_c, p_c,
p_z, p_z_nullptr,
p_lse, p_lse,
{}, // p_acc0_biases {}, // p_acc0_biases
{}, // p_acc1_biases {}, // p_acc1_biases
...@@ -252,7 +255,6 @@ int run(int argc, char* argv[]) ...@@ -252,7 +255,6 @@ int run(int argc, char* argv[])
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
// specify workspace for problem_desc // specify workspace for problem_desc
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...@@ -277,6 +279,31 @@ int run(int argc, char* argv[]) ...@@ -277,6 +279,31 @@ int run(int argc, char* argv[])
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
argument =
gemm.MakeArgument(p_a,
p_b0,
p_b1,
p_c,
p_z,
p_lse,
{}, // p_acc0_biases
{}, // p_acc1_biases
problem_descs,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
const int& G0 = g0_g1_m_n_k_o[i][0]; const int& G0 = g0_g1_m_n_k_o[i][0];
......
...@@ -118,7 +118,7 @@ ...@@ -118,7 +118,7 @@
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif #endif
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
// experimental feature: in-register sub-dword transpose // experimental feature: in-register sub-dword transpose
......
...@@ -204,6 +204,7 @@ template <index_t NumDimG, ...@@ -204,6 +204,7 @@ template <index_t NumDimG,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -627,6 +628,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -627,6 +628,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
Gemm1NXdlPerWave, Gemm1NXdlPerWave,
Gemm2NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
......
...@@ -44,7 +44,8 @@ template <typename GridwiseGemm, ...@@ -44,7 +44,8 @@ template <typename GridwiseGemm,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask, typename C0MatrixMask,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout> bool IsDropout,
bool IsLseStoring>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -100,13 +101,13 @@ __global__ void ...@@ -100,13 +101,13 @@ __global__ void
const index_t global_thread_id = get_thread_global_1d_id(); const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset); ck::philox ph(seed, global_thread_id, offset);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
nullptr ? nullptr : p_z_grid + z_batch_offset, nullptr ? nullptr : p_z_grid + z_batch_offset,
p_lse_grid + lse_batch_offset, nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -596,6 +597,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -596,6 +597,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
if(p_lse_grid == nullptr)
{
is_lse_storing_ = false;
}
} }
void Print() const void Print() const
...@@ -669,6 +676,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -669,6 +676,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
unsigned long long seed_; unsigned long long seed_;
unsigned long long offset_; unsigned long long offset_;
bool is_dropout_; bool is_dropout_;
bool is_lse_storing_ = true;
}; };
// Invoker // Invoker
...@@ -692,7 +701,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -692,7 +701,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) { auto launch_kernel = [&](auto has_main_k_block_loop_,
auto is_dropout_,
auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle< const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
...@@ -715,7 +726,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -715,7 +726,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
C0MatrixMask, C0MatrixMask,
has_main_k_block_loop_, has_main_k_block_loop_,
is_dropout_>; is_dropout_,
is_lse_storing_>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -754,29 +766,69 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -754,29 +766,69 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
if(arg.is_dropout_) if(arg.is_dropout_)
{
if(arg.is_lse_storing_)
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}, ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}, ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
} }
else else
{
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
}
else
{ {
if(arg.is_dropout_) if(arg.is_dropout_)
{
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}, ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}, ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
} }
}
return ave_time; return ave_time;
} }
......
...@@ -51,6 +51,7 @@ template <typename DataType, ...@@ -51,6 +51,7 @@ template <typename DataType,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -726,9 +727,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -726,9 +727,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static_assert(Sum_M % MPerXdl == 0, ""); static_assert(Sum_M % MPerXdl == 0, "");
static constexpr index_t GemmNWave = 2; static constexpr index_t GemmNWave = Free0_N / Gemm2NXdlPerWave / MPerXdl;
static constexpr index_t GemmOWave = BlockSize / get_warp_size() / GemmNWave; static constexpr index_t GemmOWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl; static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl; static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMLoop = Free1_M / Sum_M; static constexpr index_t GemmMLoop = Free1_M / Sum_M;
static constexpr index_t GemmMPack = static constexpr index_t GemmMPack =
......
...@@ -273,11 +273,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -273,11 +273,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
if(Gemm1N != K) // if(Gemm1N != K)
{ //{
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n'; // std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false; // return false;
} //}
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment