Unverified Commit bf80ceee authored by guangzlu's avatar guangzlu Committed by GitHub
Browse files

Merge branch 'attn-fwd-dropout-verify' into attn-bwd-dropout

parents 9dc05eaa a2887630
......@@ -33,6 +33,7 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
......@@ -42,6 +43,7 @@ using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ZDataType = U16;
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
......@@ -69,6 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
......@@ -78,6 +81,7 @@ using DeviceGemmInstance =
B0DataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
......@@ -159,4 +163,5 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_grouped_multihead_attention_forward.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
......@@ -48,6 +48,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0;
std::vector<const void*> p_b1;
std::vector<void*> p_c;
std::vector<void*> p_z;
std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o;
......@@ -55,6 +56,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<B0DataType>> b0_tensors;
std::vector<Tensor<B1DataType>> b1_tensors;
std::vector<Tensor<CDataType>> c_tensors;
std::vector<Tensor<ZDataType>> z_tensors;
std::vector<Tensor<LSEDataType>> lse_tensors;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
......@@ -62,6 +64,7 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> b0_tensors_device;
std::vector<DeviceMemPtr> b1_tensors_device;
std::vector<DeviceMemPtr> c_tensors_device;
std::vector<DeviceMemPtr> z_tensors_device;
std::vector<DeviceMemPtr> lse_tensors_device;
std::size_t flop = 0, num_byte = 0;
......@@ -102,6 +105,12 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
......@@ -114,6 +123,8 @@ int run(int argc, char* argv[])
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
lse_gs_ms_strides,
{}, // acc0_biases_gs_ms_ns_lengths
......@@ -126,6 +137,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
int Batch = G0 * G1;
......@@ -140,10 +152,13 @@ int run(int argc, char* argv[])
<< "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
<< "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
<< "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
<< std::endl;
}
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
switch(init_method)
{
case 0: break;
......@@ -172,6 +187,7 @@ int run(int argc, char* argv[])
b0_tensors.push_back(b0_gs_ns_ks);
b1_tensors.push_back(b1_gs_os_ns);
c_tensors.push_back(c_gs_ms_os_device_result);
z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms_device_result);
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
......@@ -182,6 +198,8 @@ int run(int argc, char* argv[])
sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize()));
......@@ -193,6 +211,7 @@ int run(int argc, char* argv[])
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
}
......@@ -209,6 +228,7 @@ int run(int argc, char* argv[])
p_b0,
p_b1,
p_c,
p_z,
p_lse,
{}, // p_acc0_biases
{}, // p_acc1_biases
......
......@@ -79,6 +79,7 @@ template <index_t NumDimG,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
......@@ -104,6 +105,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides;
std::vector<index_t> z_gs_ms_ns_lengths;
std::vector<index_t> z_gs_ms_ns_strides;
std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides;
......@@ -119,6 +123,7 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
......
......@@ -35,6 +35,7 @@ template <typename FloatAB,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename ZGridDesc_M_N,
typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -97,6 +98,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
......@@ -116,6 +119,65 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__host__ __device__ static constexpr auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N) ////=> for z use
{
constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(M, N)),
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
......@@ -323,6 +385,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
......@@ -367,6 +432,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
unsigned short* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
......@@ -379,6 +445,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
......@@ -782,6 +850,79 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
///////////////////=>z for dropout
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
n4, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
///////////////////=>z for dropout
do
{
auto n_block_data_idx_on_grid =
......@@ -876,8 +1017,34 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if constexpr(IsDropout) // dropout
{
blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
// save z to global
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
true>(
acc_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), true>(
acc_thread_buf, ph);
}
}
//if constexpr(IsDropout) // dropout
//{
// blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
//}
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment