Commit 381a7317 authored by letaoqin's avatar letaoqin
Browse files

redefine interface

parent 6fa4feac
...@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1 Gemm1
*/ */
#define DIM 64 // DIM should be a multiple of 8. #define DIM 128 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -53,7 +53,7 @@ using CDataType = DataType; ...@@ -53,7 +53,7 @@ using CDataType = DataType;
using DDataType = F16; using DDataType = F16;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<DDataType>;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
...@@ -80,7 +80,7 @@ static constexpr bool Deterministic = false; ...@@ -80,7 +80,7 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -151,7 +151,7 @@ using DeviceGemmInstance = ...@@ -151,7 +151,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -222,7 +222,7 @@ using DeviceGemmInstance = ...@@ -222,7 +222,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -137,7 +137,7 @@ int run(int argc, char* argv[]) ...@@ -137,7 +137,7 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1,1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2});
break; break;
case 2: case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
...@@ -188,11 +188,10 @@ int run(int argc, char* argv[]) ...@@ -188,11 +188,10 @@ int run(int argc, char* argv[])
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(nullptr),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; std::array<void*, 1>{d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; {}, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
...@@ -201,13 +200,11 @@ int run(int argc, char* argv[]) ...@@ -201,13 +200,11 @@ int run(int argc, char* argv[])
b1_gs_os_ns_strides, b1_gs_os_ns_strides,
c_gs_ms_os_lengths, c_gs_ms_os_lengths,
c_gs_ms_os_strides, c_gs_ms_os_strides,
d_gs_ms_ns_lengths,
d_gs_ms_ns_strides,
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op, a_element_op,
...@@ -244,18 +241,17 @@ int run(int argc, char* argv[]) ...@@ -244,18 +241,17 @@ int run(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
// run for storing z tensor // run for storing z tensor
argument = gemm.MakeArgument( argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; std::array<void*, 1>{
{}, // std::array<void*, 1> p_acc1_biases; d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
...@@ -264,13 +260,13 @@ int run(int argc, char* argv[]) ...@@ -264,13 +260,13 @@ int run(int argc, char* argv[])
b1_gs_os_ns_strides, b1_gs_os_ns_strides,
c_gs_ms_os_lengths, c_gs_ms_os_lengths,
c_gs_ms_os_strides, c_gs_ms_os_strides,
d_gs_ms_ns_lengths,
d_gs_ms_ns_strides,
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, std::array<std::vector<ck::index_t>, 1>{
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op, a_element_op,
...@@ -326,11 +322,8 @@ int run(int argc, char* argv[]) ...@@ -326,11 +322,8 @@ int run(int argc, char* argv[])
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias
//bias acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); });
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
self(idx) += d_g_m_n(idx);
});
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
......
...@@ -127,70 +127,6 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator ...@@ -127,70 +127,6 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionBiasForward : public BaseOperator
{
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
const void* p_d,
void* p_z,
void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& d_gs_ms_ns_lengths, // d_gs_ms_os_lengths
const std::vector<index_t>& d_gs_ms_ns_strides, // d_gs_ms_os_strides
const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths
const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -124,9 +124,9 @@ __global__ void ...@@ -124,9 +124,9 @@ __global__ void
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
nullptr ? nullptr : p_d_grid + d_batch_offset, p_d_grid == nullptr ? nullptr : p_d_grid + d_batch_offset,
nullptr ? nullptr : p_z_grid + z_batch_offset, p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
nullptr ? nullptr : p_lse_grid + lse_batch_offset, p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -157,9 +157,9 @@ __global__ void ...@@ -157,9 +157,9 @@ __global__ void
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
p_d_grid + d_batch_offset, p_d_grid == nullptr ? nullptr : p_d_grid + d_batch_offset,
nullptr ? nullptr : p_z_grid + z_batch_offset, p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
nullptr ? nullptr : p_lse_grid + lse_batch_offset, p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -288,28 +288,27 @@ template <index_t NumDimG, ...@@ -288,28 +288,27 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
: public DeviceBatchedMultiheadAttentionBiasForward<NumDimG, : public DeviceBatchedMultiheadAttentionForward<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
ADataType, ADataType,
BDataType, BDataType,
B1DataType, B1DataType,
CDataType, CDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
Acc1BiasDataType, Acc1BiasDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
MaskingSpec> MaskingSpec>
{ {
using DDataType = ADataType;
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
...@@ -317,7 +316,12 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -317,7 +316,12 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination // TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias <= 1, "Acc0 Bias addition is max support one bias");
static_assert(NumAcc1Bias == 0, "Acc1 Bias addition is unimplemented");
static_assert(NumAcc1Bias == 0
? true
: std::is_same_v<ADataType, ck::tuple_element_t<0, Acc0BiasDataType>>);
using DDataType = ADataType;
#if 0 #if 0
// TODO ANT: use alias // TODO ANT: use alias
...@@ -329,7 +333,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -329,7 +333,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2; using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -574,7 +578,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -574,7 +578,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const BDataType* p_b_grid, const BDataType* p_b_grid,
const B1DataType* p_b1_grid, const B1DataType* p_b1_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
const DDataType* p_d_grid,
ZDataType* p_z_grid, ZDataType* p_z_grid,
LSEDataType* p_lse_grid, LSEDataType* p_lse_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
...@@ -587,8 +590,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -587,8 +590,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
...@@ -609,7 +610,8 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -609,7 +610,8 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_d_grid_{p_d_grid}, p_d_grid_{NumAcc0Bias == 0 ? nullptr
: static_cast<const DDataType*>(p_acc0_biases[0])},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
p_lse_grid_{p_lse_grid}, p_lse_grid_{p_lse_grid},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
...@@ -620,7 +622,10 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -620,7 +622,10 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
d_grid_desc_m_n_{MakeZGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides)}, d_grid_desc_m_n_{NumAcc0Bias == 0
? DGridDesc_M_N{}
: MakeZGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[0],
acc0_biases_gs_ms_ns_strides[0])},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])}, lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
a_grid_desc_g_m_k_{ a_grid_desc_g_m_k_{
...@@ -631,8 +636,10 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -631,8 +636,10 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
d_grid_desc_g_m_n_{ d_grid_desc_g_m_n_{NumAcc0Bias == 0 ? DGridDesc_G_M_N{}
Transform::MakeCGridDescriptor_G_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides)}, : Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths[0],
acc0_biases_gs_ms_ns_strides[0])},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
...@@ -666,10 +673,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -666,10 +673,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())} type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{ {
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases; ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths; ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides; ignore = acc1_biases_gs_ms_gemm1ns_strides;
...@@ -1052,7 +1056,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -1052,7 +1056,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const BDataType* p_b, const BDataType* p_b,
const B1DataType* p_b1, const B1DataType* p_b1,
CDataType* p_c, CDataType* p_c,
const DDataType* p_d,
ZDataType* p_z, ZDataType* p_z,
LSEDataType* p_lse, LSEDataType* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
...@@ -1065,8 +1068,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -1065,8 +1068,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
...@@ -1088,7 +1089,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -1088,7 +1089,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
p_b, p_b,
p_b1, p_b1,
p_c, p_c,
p_d,
p_z, p_z,
p_lse, p_lse,
p_acc0_biases, p_acc0_biases,
...@@ -1101,8 +1101,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -1101,8 +1101,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
d_gs_ms_ns_lengths,
d_gs_ms_ns_strides,
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
...@@ -1128,7 +1126,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -1128,7 +1126,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const void* p_b, const void* p_b,
const void* p_b1, const void* p_b1,
void* p_c, void* p_c,
const void* p_d,
void* p_z, void* p_z,
void* p_lse, void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
...@@ -1141,8 +1138,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -1141,8 +1138,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
...@@ -1164,7 +1159,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -1164,7 +1159,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1), static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
static_cast<const DDataType*>(p_d),
static_cast<ZDataType*>(p_z), static_cast<ZDataType*>(p_z),
static_cast<LSEDataType*>(p_lse), static_cast<LSEDataType*>(p_lse),
p_acc0_biases, // cast in struct Argument p_acc0_biases, // cast in struct Argument
...@@ -1177,8 +1171,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -1177,8 +1171,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
d_gs_ms_ns_lengths,
d_gs_ms_ns_strides,
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
...@@ -1207,7 +1199,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2 ...@@ -1207,7 +1199,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2" str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment