Commit fa066d60 authored by letaoqin's avatar letaoqin
Browse files

Gridwise kernel: generalize the single bias tensor D to a tuple of multiple D0 tensors

parent ee275d4d
...@@ -25,6 +25,7 @@ namespace ck { ...@@ -25,6 +25,7 @@ namespace ck {
* *
*/ */
template <typename FloatAB, template <typename FloatAB,
typename D0sDataType,
typename ZDataType, typename ZDataType,
typename FloatGemm, typename FloatGemm,
typename FloatGemmAcc, typename FloatGemmAcc,
...@@ -39,6 +40,7 @@ template <typename FloatAB, ...@@ -39,6 +40,7 @@ template <typename FloatAB,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename D0sGridDesc_M_N,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename ZGridDesc_M_N, typename ZGridDesc_M_N,
...@@ -99,7 +101,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -99,7 +101,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
D0BlockTransferSrcScalarPerVector == 2 || D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4, D0BlockTransferSrcScalarPerVector == 4,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4"); "D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
using DDataType = FloatAB; static constexpr index_t NumD0Tensor = D0sDataType::Size();
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -414,6 +416,53 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -414,6 +416,53 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
c_grid_desc_m_n); c_grid_desc_m_n);
} }
// Build a tuple of typed null device pointers, one slot per D0 (bias) tensor.
// Used only to derive D0sGridPointer's type via decltype; the real pointers are
// filled in by the caller at kernel-argument setup time.
static constexpr auto MakeD0sGridPointer()
{
    return generate_tuple(
        [&](auto tensor_idx) {
            // Element type of slot `tensor_idx` comes from the D0sDataType tuple.
            return static_cast<
                const remove_cvref_t<tuple_element_t<tensor_idx.value, D0sDataType>>*>(
                nullptr);
        },
        Number<NumD0Tensor>{});
}
// D0 desc for source in blockwise copy.
// Transforms a flat (M, N) bias descriptor into the 10-dimensional layout
// consumed by the gemm0 threadwise copy:
//   M -> (M/MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)        dims 0,2,4,6
//   N -> (N/NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5)     dims 1,3,5,7,8,9
// NOTE(review): assumes M % MPerBlock == 0 and N % NPerBlock == 0 — the integer
// divisions below truncate otherwise; presumably validated by the caller.
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
    const auto M = d0_grid_desc_m_n.GetLength(I0);
    const auto N = d0_grid_desc_m_n.GetLength(I1);

    // N3/N4/N5 mirror the register layout of the MFMA instruction selected for
    // FloatAB at MPerXdl x NPerXdl: groups per block, input blocks, group size.
    constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
    constexpr auto N3   = mfma.num_groups_per_blk;
    constexpr auto N4   = mfma.num_input_blks;
    constexpr auto N5   = mfma.group_size;

    return transform_tensor_descriptor(
        d0_grid_desc_m_n,
        make_tuple(make_unmerge_transform(
                       make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
                   make_unmerge_transform(
                       make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        // Interleave the unmerged M and N components into the final dim order
        // (M0, N0, M1, N1, M2, N2, M3, N3, N4, N5).
        make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
// D0s desc for source in blockwise copy: apply the single-tensor
// M0..N5 transform to each D0 descriptor in the tuple, producing a
// tuple of 10-dimensional descriptors (one per bias tensor).
__host__ __device__ static constexpr auto
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
{
    return generate_tuple(
        [&](auto tensor_idx) {
            return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
                ds_grid_desc_m_n[tensor_idx]);
        },
        Number<NumD0Tensor>{});
}
// Tuple of const pointers to the D0 (bias) tensors' global-memory data,
// one element per tensor in D0sDataType.
using D0sGridPointer = decltype(MakeD0sGridPointer());
// Tuple of the 10-dimensional (M0..N5) grid descriptors for the D0 tensors,
// derived from the flat D0sGridDesc_M_N template parameter.
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
    MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
...@@ -475,9 +524,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -475,9 +524,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
typename C0MatrixMask> typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
D0sGridPointer p_d0s_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const DDataType* __restrict__ p_d_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid, FloatLSE* __restrict__ p_lse_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
...@@ -488,11 +537,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -488,11 +537,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
...@@ -907,7 +956,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -907,7 +956,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const auto wave_id = GetGemm0WaveIdx(); const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// bias (d matrix) // bias (d matrix)
constexpr auto d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockId I1, // NBlockId
m0, // MRepeat m0, // MRepeat
...@@ -919,36 +968,49 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -919,36 +968,49 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
n3, // NInputNum n3, // NInputNum
n4)); // RegisterNum n4)); // RegisterNum
auto d_threadwise_copy_globla_vgpr = auto d0s_threadwise_copy = generate_tuple(
ThreadwiseTensorSliceTransfer_v2<DDataType, [&](auto i) {
DDataType, using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
decltype(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), return ThreadwiseTensorSliceTransfer_v2<
decltype(d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), D0DataType,
Sequence<I1, // MBlockId D0DataType,
I1, // NBlockID decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
m0, // MRepeat decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
n0, // NRepeat Sequence<I1, // MBlockId
m1, // MWaveId I1, // NBlockID
n1, // NWaveId m0, // MRepeat
m2, // MPerXdl n0, // NRepeat
n2, // NGroupNum m1, // MWaveId
n3, // NInputNum n1, // NWaveId
n4>, m2, // MPerXdl
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, n2, // NGroupNum
9, n3, // NInputNum
D0BlockTransferSrcScalarPerVector, n4>,
1, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
false>(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, 9,
make_multi_index(block_work_idx_m, // MBlockId D0BlockTransferSrcScalarPerVector,
0, // NBlockId 1,
0, // mrepeat false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
0, // nrepeat make_multi_index(block_work_idx_m, // MBlockId
wave_id[I0], // MWaveId 0, // NBlockId
wave_id[I1], // NWaveId 0, // mrepeat
wave_m_n_id[I1], // MPerXdl 0, // nrepeat
0, // group wave_id[I0], // MWaveId
wave_m_n_id[I0], // NInputIndex wave_id[I1], // NWaveId
0)); // register number wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
},
Number<NumD0Tensor>{});
const auto d0s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0s_grid[i],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize());
},
Number<NumD0Tensor>{});
// z is random number matrix for dropout verify // z is random number matrix for dropout verify
// //
...@@ -1219,33 +1281,30 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1219,33 +1281,30 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias // add bias
if(p_d_grid) static_for<0, NumD0Tensor, 1>{}([&](auto i) {
{ // get register
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
p_d_grid, d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
DDataType, D0DataType,
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true> true>
d_thread_buf; d0_thread_buf;
d_threadwise_copy_globla_vgpr.Run( // load data from global
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
d_grid_buf, d0s_grid_buf[i],
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf); d0_thread_buf);
static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
// acc add bias
acc_thread_buf(i) += d_thread_buf[i];
});
d_threadwise_copy_globla_vgpr.MoveSrcSliceWindow( // acc add bias
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, static_for<0, m0 * n0 * n2 * n4, 1>{}(
[&](auto j) { acc_thread_buf(j) += d0_thread_buf[j]; });
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} });
// softmax // softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf; SoftmaxBuf& max = blockwise_softmax.max_value_buf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment