Commit 98212afe authored by letaoqin

batched gemm: reduce the bias interface to a single tensor per bias

parent 102c9661
......@@ -53,8 +53,8 @@ using CDataType = DataType;
using DDataType = F16;
using ZDataType = U16; // INT32
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<DDataType>;
using Acc1BiasDataType = ck::Tuple<>;
using Acc0BiasDataType = DDataType;
using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
......
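Note on the new convention: the ck::Tuple-based bias type lists are replaced by one data type per bias, with void meaning "no bias tensor". As a minimal illustrative sketch (not part of the commit; GetNumBiasTensors is a hypothetical helper), the bias count that used to come from Tuple::Size() can be derived from the type itself:

// Illustrative only: void == "no bias", any non-void type == one bias tensor.
// Requires <type_traits>.
template <typename BiasDataType>
constexpr ck::index_t GetNumBiasTensors()
{
    return std::is_void<BiasDataType>::value ? 0 : 1;
}

static_assert(GetNumBiasTensors<Acc0BiasDataType>() == 1, "this example uses a D0 bias");
static_assert(GetNumBiasTensors<Acc1BiasDataType>() == 0, "no Acc1 bias in this example");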
......@@ -190,8 +190,8 @@ int run(int argc, char* argv[])
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases
nullptr,                                                 // p_acc1_biases (unused)
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
......@@ -203,10 +203,10 @@ int run(int argc, char* argv[])
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths (std::vector<ck::index_t>)
{}, // acc1_biases_gs_ms_os_strides (std::vector<ck::index_t>)
a_element_op,
b0_element_op,
acc0_element_op,
......@@ -230,7 +230,7 @@ int run(int argc, char* argv[])
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * Acc0BiasDataType::Size()) *
sizeof(DDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
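The D0 bias term in num_btype should only contribute when a bias tensor is actually present. A minimal restatement of that accounting under the new type convention (my sketch, reusing the example's variables; num_bytes_moved mirrors num_btype):

// Illustrative restatement: count D0 bias traffic only when Acc0BiasDataType is non-void.
const std::size_t d0_bytes =
    std::is_void<Acc0BiasDataType>::value ? std::size_t{0} : sizeof(DDataType) * M * N;
const std::size_t num_bytes_moved =
    (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
     sizeof(CDataType) * M * O + d0_bytes) *
    BatchCount;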
......@@ -250,9 +250,8 @@ int run(int argc, char* argv[])
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{
d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases
nullptr,                                                 // p_acc1_biases (unused)
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
......@@ -264,12 +263,10 @@ int run(int argc, char* argv[])
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{},                 // acc1_biases_gs_ms_os_lengths
{},                 // acc1_biases_gs_ms_os_strides
a_element_op,
b0_element_op,
acc0_element_op,
......
......@@ -87,8 +87,8 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
{
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
static constexpr index_t NumAcc0Bias = 1; // single D0 bias tensor (Acc0BiasDataType)
static constexpr index_t NumAcc1Bias = 0; // Acc1 bias addition is unimplemented
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
......@@ -97,8 +97,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
void* p_c,
void* p_z,
void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const void* p_acc0_biases,
const void* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -110,11 +110,11 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths
const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias>
const std::vector<index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<index_t>, NumAcc1Bias>
const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
......
......@@ -25,7 +25,7 @@ namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename D0sPointer,
typename D0DataType,
typename FloatC,
typename ZDataType,
typename FloatLSE,
......@@ -37,7 +37,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
......@@ -56,7 +56,7 @@ __global__ void
kernel_batched_multiheadattention_forward_xdl_cshuffle_v2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
D0sPointer p_d0s_grid,
const D0DataType* __restrict__ p_d0_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
ZDataType* __restrict__ p_z_grid,
......@@ -68,8 +68,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -107,11 +107,8 @@ __global__ void
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
});
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
// const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, 0, offset);
......@@ -125,7 +122,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_d0s_grid,
p_d0_grid == nullptr ? nullptr : p_d0_grid + d0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
......@@ -138,7 +135,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
......@@ -158,7 +155,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_d0s_grid,
p_d0_grid == nullptr ? nullptr : p_d0_grid + d0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
......@@ -171,7 +168,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
......@@ -188,7 +185,7 @@ __global__ void
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_d0s_grid;
ignore = p_d0_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = p_z_grid;
......@@ -200,7 +197,7 @@ __global__ void
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6;
......@@ -318,11 +315,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumD0Tensor = Acc0BiasDataType::Size();
static constexpr index_t NumD1Tensor = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumD1Tensor == 0, "Acc1 Bias addition is unimplemented");
static_assert(std::is_void<Acc1BiasDataType>::value, "Acc1 Bias addition is unimplemented");
using D0DataType = Acc0BiasDataType;
using D1DataType = Acc1BiasDataType;
#if 0
// TODO ANT: use alias
......@@ -406,40 +402,16 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
}
}
static auto MakeD0sGridDescriptor_M_N(
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
}
static auto MakeD0sGridDescriptor_G_M_N(
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {}));
using D0GridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
......@@ -465,14 +437,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n,
const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n),
d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
......@@ -490,11 +462,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
template <index_t I>
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
Number<I> d0_idx) const
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
{
return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
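GetD0BasePtr now mirrors the other single-tensor operands: it asks the G_M_N descriptor for the linear offset of (g_idx, 0, 0). For a plain strided-batch layout this is simply the batch index times the D0 batch stride; e.g. (hypothetical strides, for illustration only):

// If D0 is stored as [G, M, N] with strides {M * N, N, 1}, then
//   d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)) == g_idx * M * N,
// i.e. the start of batch g_idx's M x N bias slice.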
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
......@@ -520,7 +490,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
......@@ -545,7 +515,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
D0sGridDesc_M_N,
D0GridDesc_M_N,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
ZGridDesc_M_N,
......@@ -605,15 +575,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// FIXME: constness
struct Argument : public BaseArgument
{
Argument(
const ADataType* p_a_grid,
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
ZDataType* p_z_grid,
LSEDataType* p_lse_grid,
const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases,
const D0DataType* p_acc0_biases,
const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -625,11 +594,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>
const std::vector<index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumD1Tensor>
const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -640,6 +609,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_biases},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
p_z_grid_{p_z_grid},
......@@ -648,6 +618,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
d0_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides)},
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n_)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
......@@ -658,7 +633,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N(
d0_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
......@@ -690,7 +665,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
d0s_grid_desc_g_m_n_,
d0_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
......@@ -711,23 +686,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)};
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0s_grid_desc_m_n);
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n_); // use the member built from acc0_biases lengths/strides, not an empty local
}
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, Acc0BiasDataType>>;
// D0 pointer
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
// for check
d0s_n_length_stride_[i].push_back(
acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
d0s_n_length_stride_[i].push_back(
acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
});
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
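For reference, the unchanged dropout bookkeeping above stores the keep probability and its 16-bit fixed-point threshold; with p_dropout = 0.1 this works out to (my arithmetic, not from the commit):

// p_dropout            = 0.1
// p_dropout_           = 1 - 0.1 = 0.9
// p_dropout_in_16bits_ = uint16_t(floor(0.9 * 65535.0)) = 58981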
......@@ -769,7 +735,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
const D0DataType* p_d0_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
ZDataType* p_z_grid_;
......@@ -778,6 +744,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
D0GridDesc_M_N d0_grid_desc_m_n_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
ZGridDesc_M_N z_grid_desc_m_n_;
......@@ -785,9 +754,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
......@@ -833,7 +801,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t n_raw_padded_;
// raw data
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_n_length_stride_;
std::vector<ck::index_t> d0_n_length_stride_;
};
// Invoker
......@@ -864,7 +832,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::D0sGridPointer,
D0DataType,
CDataType,
ZDataType,
LSEDataType,
......@@ -876,7 +844,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
......@@ -897,7 +865,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_d0s_grid_,
arg.p_d0_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.p_z_grid_,
......@@ -909,7 +877,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
......@@ -1040,18 +1008,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
return false;
}
for(int i = 0; i < NumD0Tensor; i++)
{
if(arg.d0s_n_length_stride_[i][1] == 1 &&
arg.d0s_n_length_stride_[i][0] % Acc0BiasTransferSrcScalarPerVector != 0)
if(arg.d0_n_length_stride_[1] == 1 &&
arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
if(arg.d0s_n_length_stride_[i][1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1)
if(arg.d0_n_length_stride_[1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
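The D0 bias branch of IsSupportedArgument applies the same rule as the other operands: a vectorized load along N is only legal when the bias's innermost N dimension is contiguous and its length is a multiple of the load width. A small restatement of that rule (assuming d0_n_length_stride_ holds {length, stride} of the lowest N dimension, as recorded in the Argument constructor):

// Illustrative restatement of the D0 bias vector-load check above. Requires <vector>.
bool IsD0VectorLoadSupported(const std::vector<ck::index_t>& d0_n_length_stride,
                             ck::index_t acc0_bias_scalar_per_vector)
{
    const ck::index_t n_length = d0_n_length_stride[0];
    const ck::index_t n_stride = d0_n_length_stride[1];
    if(n_stride == 1)
        return n_length % acc0_bias_scalar_per_vector == 0;
    return acc0_bias_scalar_per_vector == 1;
}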
......@@ -1103,15 +1068,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const ADataType* p_a,
static auto
MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const B1DataType* p_b1,
CDataType* p_c,
ZDataType* p_z,
LSEDataType* p_lse,
const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases,
const D0DataType* p_acc0_biases,
const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1123,11 +1088,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>
const std::vector<index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumD1Tensor>
const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1180,8 +1145,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
void* p_c,
void* p_z,
void* p_lse,
const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases,
const void* p_acc0_biases,
const void* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1193,11 +1158,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>
const std::vector<index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<index_t>& acc0_biases_gs_ms_ns_strides,
const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumD1Tensor>
const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1207,14 +1172,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
return std::make_unique<Argument>(
static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
static_cast<ZDataType*>(p_z),
static_cast<LSEDataType*>(p_lse),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......
......@@ -25,7 +25,7 @@ namespace ck {
*
*/
template <typename FloatAB,
typename D0sDataType,
typename D0DataType,
typename ZDataType,
typename FloatGemm,
typename FloatGemmAcc,
......@@ -40,7 +40,7 @@ template <typename FloatAB,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename D0sGridDesc_M_N,
typename D0GridDesc_M_N,
typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename ZGridDesc_M_N,
......@@ -102,7 +102,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -441,20 +440,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_grid_desc_m_n);
}
static constexpr auto MakeD0sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return static_cast<const D0DataType*>(nullptr);
},
Number<NumD0Tensor>{});
}
// D0 desc for source in blockwise copy
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
......@@ -472,20 +460,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
// D0s desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]);
},
Number<NumD0Tensor>{});
}
using D0sGridPointer = decltype(MakeD0sGridPointer());
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0GridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
......@@ -544,7 +520,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
D0sGridPointer p_d0s_grid,
const D0DataType* __restrict__ p_d0_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
ZDataType* __restrict__ p_z_grid,
......@@ -557,8 +533,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -985,13 +961,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
n3, // NInputNum
n4)); // RegisterNum
auto d0s_threadwise_copy = generate_tuple(
[&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return ThreadwiseTensorSliceTransfer_v2<
D0DataType,
auto d0_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<D0DataType,
D0DataType,
decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
decltype(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, // MBlockId
I1, // NBlockID
......@@ -1007,7 +980,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
9,
D0BlockTransferSrcScalarPerVector,
1,
false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
false>(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
......@@ -1018,16 +991,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
},
Number<NumD0Tensor>{});
const auto d0s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0s_grid[i],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize());
},
Number<NumD0Tensor>{});
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
......@@ -1325,9 +1291,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
if constexpr(!std::is_void<D0DataType>::value)
{
// get register
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
StaticBuffer<AddressSpaceEnum::Vgpr,
D0DataType,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
......@@ -1335,20 +1301,20 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_thread_buf;
// load data from global
d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
d0s_grid_buf[i],
d0_threadwise_copy.Run(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0_grid_buf,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0_thread_buf);
// acc add bias
static_for<0, m0 * n0 * n2 * n4, 1>{}(
[&](auto j) { acc_thread_buf(j) += d0_thread_buf[j]; });
[&](auto i) { acc_thread_buf(i) += d0_thread_buf[i]; });
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
d0_threadwise_copy.MoveSrcSliceWindow(
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
});
}
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
......