"...composable_kernel.git" did not exist on "756a76172780dee7396a0f288d52eb63c2c0f8fc"
Commit 9096e2af authored by ltqin's avatar ltqin
Browse files

Merge branch 'attn-bwd-develop' into attn-bwd-bhalf-test

parents ac407086 8453af0c
......@@ -25,6 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_HD32 0
#include <iostream>
#include <numeric>
......@@ -86,6 +87,80 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
// Headdim/K/O should be a multiple of 8, and it's only supported up to 64 in prototype1.
// If Headdim/K/O <= 32, use 1st template.
// If 32 < Headdim/K/O <= 64, use 2nd template.
#if USING_HD32
// 1st template
// Device-op instance tuned for the small-head-dim case (Headdim/K/O <= 32,
// selected when USING_HD32 is set). Element-wise op ordering follows the
// MakeArgument call order (a, b0, acc0, b1, c): Q, K, softmax scale, V, Y.
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp, // A (Q) element-wise op
QKVElementOp, // B0 (K) element-wise op
Scale, // Acc0 element-wise op (softmax scaling)
QKVElementOp, // B1 (V) element-wise op
YElementOp, // C (Y) element-wise op
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1, // NOTE(review): presumably NumGemmKPrefetchStage -- confirm against template decl
256, // NOTE(review): presumably BlockSize -- confirm against template decl
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, // NOTE(review): presumably thread-cluster arrange order -- confirm
S<1, 0, 2>, // NOTE(review): presumably src access order -- confirm
2, // NOTE(review): presumably src vector dim -- confirm
8, // NOTE(review): presumably src scalar-per-vector -- confirm
8, // NOTE(review): presumably dst scalar-per-vector (AK1) -- confirm
true, // NOTE(review): presumably LDS extra-padding flag -- confirm
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#else
// 2nd template
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG,
......@@ -125,6 +200,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -151,6 +227,7 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#endif
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
......
......@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
bool time_kernel = true;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
......@@ -175,7 +175,7 @@ int run(int argc, char* argv[])
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
......@@ -228,6 +228,44 @@ int run(int argc, char* argv[])
if(do_verification)
{
// run for storing z tensor
argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number
// of elements on a thread
c_device_buf.SetZero();
lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
lse_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
......
......@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
bool time_kernel = true;
bool input_permute = false;
bool output_permute = true;
......@@ -56,7 +56,8 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0;
std::vector<const void*> p_b1;
std::vector<void*> p_c;
std::vector<void*> p_z;
std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o;
......@@ -221,6 +222,7 @@ int run(int argc, char* argv[])
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
}
......@@ -233,12 +235,13 @@ int run(int argc, char* argv[])
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(p_a,
p_b0,
p_b1,
p_c,
p_z,
p_z_nullptr,
p_lse,
{}, // p_acc0_biases
{}, // p_acc1_biases
......@@ -252,7 +255,6 @@ int run(int argc, char* argv[])
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
......@@ -277,6 +279,31 @@ int run(int argc, char* argv[])
bool pass = true;
if(do_verification)
{
argument =
gemm.MakeArgument(p_a,
p_b0,
p_b1,
p_c,
p_z,
p_lse,
{}, // p_acc0_biases
{}, // p_acc1_biases
problem_descs,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
for(std::size_t i = 0; i < group_count; i++)
{
const int& G0 = g0_g1_m_n_k_o[i][0];
......
......@@ -118,7 +118,7 @@
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
// experimental feature: in-register sub-dword transpose
......
......@@ -204,6 +204,7 @@ template <index_t NumDimG,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -627,6 +628,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
Gemm2NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
......
......@@ -44,7 +44,8 @@ template <typename GridwiseGemm,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop,
bool IsDropout>
bool IsDropout,
bool IsLseStoring>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -100,13 +101,13 @@ __global__ void
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
nullptr ? nullptr : p_z_grid + z_batch_offset,
p_lse_grid + lse_batch_offset,
nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared,
a_element_op,
b_element_op,
......@@ -596,6 +597,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
if(p_lse_grid == nullptr)
{
is_lse_storing_ = false;
}
}
void Print() const
......@@ -669,6 +676,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
unsigned long long seed_;
unsigned long long offset_;
bool is_dropout_;
bool is_lse_storing_ = true;
};
// Invoker
......@@ -692,7 +701,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
auto launch_kernel = [&](auto has_main_k_block_loop_,
auto is_dropout_,
auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
......@@ -715,7 +726,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
is_dropout_>;
is_dropout_,
is_lse_storing_>;
return launch_and_time_kernel(stream_config,
kernel,
......@@ -755,26 +767,66 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
{
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
}
else
{
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
}
......
......@@ -51,6 +51,7 @@ template <typename DataType,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -726,9 +727,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static_assert(Sum_M % MPerXdl == 0, "");
static constexpr index_t GemmNWave = 2;
static constexpr index_t GemmNWave = Free0_N / Gemm2NXdlPerWave / MPerXdl;
static constexpr index_t GemmOWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMLoop = Free1_M / Sum_M;
static constexpr index_t GemmMPack =
......
......@@ -273,11 +273,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
if(Gemm1N != K)
{
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
// if(Gemm1N != K)
//{
// std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
// return false;
//}
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment