Commit e663a0d7 authored by letaoqin

v1 gridwise gemm

parent 701879d0
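A reading of the diff below: the v1 Qloop gridwise GEMM of the batched multi-head attention backward kernel gains an optional attention-bias tensor D0 (a new D0DataType template parameter next to OutputDataType), which is staged global memory -> LDS -> VGPR inside the gemm0 M loop and added to the attention scores before the softmax. In the notation of the example's own comments, the forward relation being differentiated becomes

// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n + D0_g_m_n) * V_g_n_o

The example is also switched to the DIM 32 path with a larger batched problem (M = 512, G0 = 4, G1 = 6), apparently to exercise the new code. A minimal CPU sketch of the biased relation, for orientation only (plain C++, no CK types; all names are illustrative, not from the commit):

#include <cmath>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// y[m][o] = sum_n softmax_n(alpha * sum_k q[m][k] * kt[n][k] + d0[m][n]) * v[n][o]
Mat attention_with_bias(const Mat& q, const Mat& kt, const Mat& v, const Mat& d0, float alpha)
{
    const std::size_t M = q.size(), N = kt.size(), K = q[0].size(), O = v[0].size();
    Mat y(M, std::vector<float>(O, 0.0f));
    for(std::size_t m = 0; m < M; ++m)
    {
        std::vector<float> s(N);
        float s_max = -INFINITY;
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
                acc += q[m][k] * kt[n][k]; // kt holds K transposed: [n][k]
            s[n]  = alpha * acc + d0[m][n]; // bias added after scaling, before softmax
            s_max = std::fmax(s_max, s[n]);
        }
        float sum = 0.0f;
        for(std::size_t n = 0; n < N; ++n)
            sum += (s[n] = std::exp(s[n] - s_max)); // numerically stable softmax
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t o = 0; o < O; ++o)
                y[m][o] += (s[n] / sum) * v[n][o];
    }
    return y;
}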
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#define DIM 32 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
@@ -104,11 +104,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
// clang-format off
using DeviceGemmInstance =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 4, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 4, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 1, S<1, 64, 1, 4>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 64)
// clang-format off
......
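Side note on the instance change above: the removed and added DeviceGemmInstance lines appear to differ only in the new D0BlockTransfer SrcScalarPerVector argument (the 4 inserted after the second true, ahead of the two CShuffle PerShuffle 1s), matching the D0BlockTransfer column added to the header comment; the surrounding tuning parameters are unchanged.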
@@ -273,12 +273,12 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 128;
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 1;
ck::index_t G1 = 1;
ck::index_t G0 = 4;
ck::index_t G1 = 6;
bool input_permute = false;
bool output_permute = false;
@@ -420,6 +420,7 @@ int run(int argc, char* argv[])
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
// d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
......
@@ -21,8 +21,8 @@
namespace ck {
template <typename InputDataType,
typename OutputDataType,
typename D0DataType,
typename OutputDataType,
typename ZDataType,
typename GemmDataType,
typename FloatGemmAcc,
@@ -1249,6 +1249,98 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
struct D0Loader
{
template <typename DataType>
struct TypeTransform
{
using Type = DataType;
};
template <>
struct TypeTransform<void>
{
using Type = ck::half_t;
};
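// Note: the void specialization presumably exists so the copy aliases below stay
// well-formed when the bias is disabled (D0DataType == void); the actual D0 data
// path is compiled out by the if constexpr(!is_same<D0DataType, void>::value)
// guard in the main loop, so the ck::half_t stand-in is never touched at runtime.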
static constexpr index_t NThreadClusterLengths = MPerXdl;
static_assert(MPerXdl <= KPerBlock);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
__host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
{
// D0 tile in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2),
make_tuple(Number<NPerBlock>{} * D0M2,
Number<NPerBlock>{} * D0M2,
Number<NPerBlock>{} * D0M2,
Number<NPerBlock>{} * D0M2,
D0M2,
I1));
}
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
{
constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor(make_tuple(D0M1, Number<NPerBlock>{}, D0M2),
make_tuple(Number<NPerBlock>{} * D0M2, D0M2, I1));
constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
d0_raw_m0_n_m1,
make_tuple(make_unmerge_transform(make_tuple(D0M1 / I2, I2)),
make_unmerge_transform(
make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
make_pass_through_transform(D0M2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2_m3;
}
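// Note: the read descriptor appears to re-tile the same LDS block so that N
// splits into (NPerBlock / NPerXdl) waves of NPerXdl lanes and the M rows split
// as (D0M1 / 2, 2, D0M2), i.e. the layout that the xdlops C-thread coordinates
// (wave_id / wave_m_n_id, used when constructing D0ThreadCopy) index into.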
static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_desc_n0_n1_m0_m1_m2_m3 =
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
5, // DstVectorDim
D0BlockTransferSrcScalarPerVector, // SrcScalarPerVector
4, // DstScalarPerVector
1,    // SrcScalarStrideInVector
1,    // DstScalarStrideInVector
true, // SrcResetCoord
true, // DstResetCoord
1>;   // NumThreadScratch
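// Note: ThreadClusterLengths (BlockSize / MPerXdl) x MPerXdl tiles the
// D0M1 x NPerBlock slice, so each thread presumably covers a
// (D0M1 * MPerXdl / BlockSize) x (NPerBlock / MPerXdl) x D0M2 sub-slice,
// reading D0BlockTransferSrcScalarPerVector-wide vectors along N (dim 4)
// and writing 4-wide along m3 (dim 5).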
using D0ThreadCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_desc_n0_n1_m0_m1_m2_m3), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
2, // SrcScalarPerVector
2>; // SrcScalarStrideInVector
};
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
@@ -1290,8 +1382,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const index_t raw_n_padded,
const index_t block_idx_n)
{
ignore = p_d0_grid;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
@@ -1827,6 +1917,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
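// Note: p_d0_grid and d0_grid_desc_m0_n0_m1_m2_n1_m3, previously ignored, are
// now consumed here: the global-side window starts at the last gemm0 M block
// (the outer loop walks M downward; see the (-1, ...) window move after the
// bias add), and the LDS-to-VGPR copy origin is presumably derived from the
// wave/lane ids so each thread reads the slice feeding its own accumulators.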
do
{
auto m_block_data_idx_on_grid =
@@ -1974,6 +2076,57 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Loader::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Loader::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
});
// write the prefetched K tile into LDS
gemm_tile_k_blockwise_copy.RunWrite(GemmBlockwiseCopy::k_block_desc_k0_n_k1,
k_block_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
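// Note on the window bookkeeping above: each of the D0M0 iterations moved the
// D0 source window by +1 along m1, so the final (-1, 0, -D0M0, 0, 0, 0) move
// presumably rewinds m1 and steps m0 down to the next gemm0 M block.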
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
......
@@ -1180,7 +1180,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
// template<typename DataType>
struct D0Loader
{
template <typename DataType>
......