Commit ec2ad713 authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'mha-train-develop' into mha-train-bias-bwd-type2

parents e3eb4381 e296ee56
......@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType;
using ZDataType = U16; // INT32
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
using Acc0BiasDataType = void;
using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
......
......@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType;
using ZDataType = U16; // INT32
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
using Acc0BiasDataType = void;
using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
......@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
......@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8,
true,
4,
S<8, 32, 1>, // B1BlockTransfer
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
......
......@@ -125,8 +125,8 @@ using DeviceGemmInstanceFWD =
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
void,
void,
AccDataType,
ShuffleDataType,
QKVElementOp,
......@@ -259,8 +259,8 @@ using DeviceGemmInstanceFWD =
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
void,
void,
AccDataType,
ShuffleDataType,
QKVElementOp,
......@@ -463,8 +463,8 @@ using DeviceGemmInstanceFWD =
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
void,
void,
AccDataType,
ShuffleDataType,
QKVElementOp,
......
......@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType;
using ZDataType = U16; // INT32
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
using Acc0BiasDataType = void;
using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
......
......@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType;
using ZDataType = U16; // INT32
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
using Acc0BiasDataType = void;
using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
......@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
......@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8,
true,
1,
S<8, 32, 1>, // B1BlockTransfer
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
......
......@@ -124,8 +124,8 @@ using DeviceGemmInstanceFWD =
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
void,
void,
AccDataType,
ShuffleDataType,
QKVElementOp,
......@@ -258,8 +258,8 @@ using DeviceGemmInstanceFWD =
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
void,
void,
AccDataType,
ShuffleDataType,
QKVElementOp,
......@@ -462,8 +462,8 @@ using DeviceGemmInstanceFWD =
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
void,
void,
AccDataType,
ShuffleDataType,
QKVElementOp,
......
......@@ -177,8 +177,8 @@ int run(int argc, char* argv[])
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // std::array<void*, 1> p_acc0_biases;
nullptr, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
......
......@@ -50,11 +50,10 @@ using B1DataType = DataType;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = DataType;
using DDataType = F16;
using ZDataType = U16; // INT32
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<DDataType>;
using Acc1BiasDataType = ck::Tuple<>;
using Acc0BiasDataType = F16;
using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
......@@ -122,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -195,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -268,6 +269,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......
......@@ -48,13 +48,12 @@ using ADataType = DataType;
using B0DataType = DataType;
using B1DataType = DataType;
using AccDataType = F32;
using DDataType = F16;
using CShuffleDataType = F32;
using CDataType = DataType;
using ZDataType = U16; // INT32
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<DDataType>;
using Acc1BiasDataType = ck::Tuple<>;
using Acc0BiasDataType = F16;
using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
......@@ -122,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -195,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -268,6 +269,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......
......@@ -116,7 +116,7 @@ int run(int argc, char* argv[])
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
......@@ -137,25 +137,25 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
......@@ -163,7 +163,7 @@ int run(int argc, char* argv[])
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) *
lse_gs_ms_device_result.mDesc.GetElementSpaceSize());
......@@ -183,15 +183,15 @@ int run(int argc, char* argv[])
// TODO ANT: replace array with vector?
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), //
nullptr,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
......@@ -203,18 +203,18 @@ int run(int argc, char* argv[])
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // std::vector<ck::index_t>
{}, // std::vector<ck::index_t>
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number of
// elements on a thread
{seed, offset}); // dropout random seed and offset, offset should be at
// least the number of elements on a thread
if(!gemm.IsSupportedArgument(argument))
{
......@@ -228,14 +228,15 @@ int run(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * Acc0BiasDataType::Size()) *
std::size_t num_bytes =
(sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
sizeof(CDataType) * M * O +
sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
float gb_per_sec = num_bytes / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
......@@ -243,16 +244,15 @@ int run(int argc, char* argv[])
if(do_verification)
{
// run for storing z tensor
argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
argument =
gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{
d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()),
nullptr,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
......@@ -264,20 +264,18 @@ int run(int argc, char* argv[])
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
d_gs_ms_ns_lengths,
d_gs_ms_ns_strides,
{},
{},
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number
// of elements on a thread
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
c_device_buf.SetZero();
lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
......@@ -294,7 +292,7 @@ int run(int argc, char* argv[])
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
Tensor<LSEDataType> lse_g_m_host_result(
{BatchCount, M}); // scratch object after max + ln(sum)
Tensor<DDataType> d_g_m_n({G0 * G1, M, N});
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
......@@ -324,12 +322,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); });
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
self(idx) = -ck::NumericLimits<AccDataType>::Infinity();
});
// softmax
......
......@@ -57,7 +57,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0;
std::vector<const void*> p_b1;
std::vector<void*> p_c;
std::vector<std::vector<const void*>> p_d;
std::vector<const void*> p_d;
std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse;
......@@ -67,7 +67,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<B0DataType>> b0_tensors;
std::vector<Tensor<B1DataType>> b1_tensors;
std::vector<Tensor<CDataType>> c_tensors;
std::vector<Tensor<DDataType>> d_tensors;
std::vector<Tensor<Acc0BiasDataType>> d_tensors;
std::vector<Tensor<ZDataType>> z_tensors;
std::vector<Tensor<LSEDataType>> lse_tensors;
......@@ -147,10 +147,8 @@ int run(int argc, char* argv[])
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
lse_gs_ms_strides,
std::vector<std::vector<ck::index_t>>{
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::vector<std::vector<ck::index_t>>{
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}}); // acc1_biases_gs_ms_os_strides
......@@ -159,7 +157,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
......@@ -167,7 +165,7 @@ int run(int argc, char* argv[])
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * (Acc0BiasDataType::Size() ? 0 : 1)) *
sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
Batch;
if(i < 4)
......@@ -191,25 +189,25 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
}
a_tensors.push_back(a_gs_ms_ks);
......@@ -229,7 +227,7 @@ int run(int argc, char* argv[])
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
d_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()));
sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()));
z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
......@@ -244,9 +242,7 @@ int run(int argc, char* argv[])
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_d.push_back({d_tensors_device[i]->GetDeviceBuffer()});
// std::cout << "from host group id: " << i << " d address: " <<
// d_tensors_device[i]->GetDeviceBuffer() << std::endl;
p_d.push_back(d_tensors_device[i]->GetDeviceBuffer());
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
......@@ -363,7 +359,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
Tensor<AccDataType> d_g_m_n({G0 * G1, M, N});
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
......@@ -400,12 +396,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); });
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
self(idx) = -ck::NumericLimits<AccDataType>::Infinity();
});
// softmax
......
......@@ -138,12 +138,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw);
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
}
block_sync_lds();
......@@ -179,12 +179,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw);
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
}
block_sync_lds();
......@@ -218,21 +218,19 @@ struct BlockwiseDropout
}
// get raw z matrix with random number for shuffle
template <typename ZThreadBuffer,
typename Step,
typename Offset> // N3*N4=8
template <typename ZThreadBuffer, typename Step, typename Offset>
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf)
{
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
int philox_calls = tmp_size / 4;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{});
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{});
}
static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
......
......@@ -87,9 +87,6 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
{
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b0,
......@@ -97,8 +94,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
void* p_c,
void* p_z,
void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const void* p_acc0_bias,
const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -110,12 +107,10 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths
const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
const std::vector<index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<index_t>& acc1_bias_gs_ms_gemm1ns_lengths,
const std::vector<index_t>& acc1_bias_gs_ms_gemm1ns_strides,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
......
......@@ -111,11 +111,11 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
std::vector<index_t> acc0_biases_gs_ms_ns_lengths;
std::vector<index_t> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
std::vector<index_t> acc1_biases_gs_ms_os_lengths;
std::vector<index_t> acc1_biases_gs_ms_os_strides;
};
virtual std::unique_ptr<BaseArgument>
......@@ -125,9 +125,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec,
std::vector<const void*> p_acc0_bias_vec,
std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
......
......@@ -289,12 +289,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
......@@ -535,15 +529,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
// FIXME: constness
struct Argument : public BaseArgument
{
Argument(
const ADataType* p_a_grid,
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
ZDataType* p_z_grid,
LSEDataType* p_lse_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const void* p_acc0_bias,
const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -555,12 +548,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
const std::vector<ck::index_t> acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t> acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t> acc1_bias_gs_ms_gemm1ns_lengths,
const std::vector<ck::index_t> acc1_bias_gs_ms_gemm1ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -624,12 +615,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
ignore = p_acc0_bias;
ignore = p_acc1_bias;
ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
......@@ -984,15 +975,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const ADataType* p_a,
static auto
MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const B1DataType* p_b1,
CDataType* p_c,
ZDataType* p_z,
LSEDataType* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const void* p_acc0_bias,
const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1004,12 +995,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_lengths,
const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -1024,8 +1013,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
p_c,
p_z,
p_lse,
p_acc0_biases,
p_acc1_biases,
p_acc0_bias,
p_acc1_bias,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1037,10 +1026,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op,
b_element_op,
acc_element_op,
......@@ -1061,8 +1050,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
void* p_c,
void* p_z,
void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const void* p_acc0_bias,
const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1074,12 +1063,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_lengths,
const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -1094,8 +1081,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
static_cast<CDataType*>(p_c),
static_cast<ZDataType*>(p_z),
static_cast<LSEDataType*>(p_lse),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
p_acc0_bias, // cast in struct Argument
p_acc1_bias, // cast in struct Argument
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1107,10 +1094,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths,
acc1_bias_gs_ms_gemm1ns_strides,
a_element_op,
b_element_op,
acc_element_op,
......
......@@ -279,12 +279,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
......@@ -603,8 +597,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<const void*> p_acc0_bias_vec,
std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -619,6 +613,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
b1_element_op_{b1_element_op},
c_element_op_{c_element_op}
{
ignore = p_acc0_bias_vec;
ignore = p_acc1_bias_vec;
// TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size();
......@@ -628,11 +625,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
}
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0;
for(std::size_t i = 0; i < group_count_; i++)
......@@ -710,18 +702,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
group_kernel_args_.push_back({p_a_grid,
p_b_grid,
p_b1_grid,
......@@ -1055,8 +1035,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<const void*> p_acc0_bias_vec,
std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1072,8 +1052,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
p_c_vec,
p_z_vec,
p_lse_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
p_acc0_bias_vec,
p_acc1_bias_vec,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1094,9 +1074,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec,
std::vector<const void*> p_acc0_bias_vec,
std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -1111,8 +1091,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
p_c_vec,
p_z_vec,
p_lse_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
p_acc0_bias_vec,
p_acc1_bias_vec,
problem_desc_vec,
a_element_op,
b_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment