"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "7673fe8307eadca597cc289045bf6f49724f2834"
Commit ec2ad713 authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'mha-train-develop' into mha-train-bias-bwd-type2

parents e3eb4381 e296ee56
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
......
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -121,6 +121,7 @@ using DeviceGemmInstance = ...@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -194,6 +195,7 @@ using DeviceGemmInstance = ...@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -257,7 +259,7 @@ using DeviceGemmInstance = ...@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -266,7 +268,8 @@ using DeviceGemmInstance = ...@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -282,7 +285,7 @@ using DeviceGemmInstance = ...@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8, 8,
true, true,
4, 4,
S<8, 32, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
......
...@@ -125,8 +125,8 @@ using DeviceGemmInstanceFWD = ...@@ -125,8 +125,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
...@@ -259,8 +259,8 @@ using DeviceGemmInstanceFWD = ...@@ -259,8 +259,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
...@@ -463,8 +463,8 @@ using DeviceGemmInstanceFWD = ...@@ -463,8 +463,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
......
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
......
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -121,6 +121,7 @@ using DeviceGemmInstance = ...@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -194,6 +195,7 @@ using DeviceGemmInstance = ...@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -257,7 +259,7 @@ using DeviceGemmInstance = ...@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -266,7 +268,8 @@ using DeviceGemmInstance = ...@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -282,7 +285,7 @@ using DeviceGemmInstance = ...@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8, 8,
true, true,
1, 1,
S<8, 32, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
......
...@@ -124,8 +124,8 @@ using DeviceGemmInstanceFWD = ...@@ -124,8 +124,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
...@@ -258,8 +258,8 @@ using DeviceGemmInstanceFWD = ...@@ -258,8 +258,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
...@@ -462,8 +462,8 @@ using DeviceGemmInstanceFWD = ...@@ -462,8 +462,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
......
...@@ -177,8 +177,8 @@ int run(int argc, char* argv[]) ...@@ -177,8 +177,8 @@ int run(int argc, char* argv[])
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; nullptr, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; nullptr, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
......
...@@ -50,11 +50,10 @@ using B1DataType = DataType; ...@@ -50,11 +50,10 @@ using B1DataType = DataType;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using DDataType = F16;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<DDataType>; using Acc0BiasDataType = F16;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -122,6 +121,7 @@ using DeviceGemmInstance = ...@@ -122,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -195,6 +195,7 @@ using DeviceGemmInstance = ...@@ -195,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -268,6 +269,7 @@ using DeviceGemmInstance = ...@@ -268,6 +269,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 4, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
......
...@@ -48,13 +48,12 @@ using ADataType = DataType; ...@@ -48,13 +48,12 @@ using ADataType = DataType;
using B0DataType = DataType; using B0DataType = DataType;
using B1DataType = DataType; using B1DataType = DataType;
using AccDataType = F32; using AccDataType = F32;
using DDataType = F16;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<DDataType>; using Acc0BiasDataType = F16;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -122,6 +121,7 @@ using DeviceGemmInstance = ...@@ -122,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -195,6 +195,7 @@ using DeviceGemmInstance = ...@@ -195,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -268,6 +269,7 @@ using DeviceGemmInstance = ...@@ -268,6 +269,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 4, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
......
...@@ -116,7 +116,7 @@ int run(int argc, char* argv[]) ...@@ -116,7 +116,7 @@ int run(int argc, char* argv[])
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
...@@ -137,25 +137,25 @@ int run(int argc, char* argv[]) ...@@ -137,25 +137,25 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
break; break;
case 2: case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break; break;
case 3: case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -163,7 +163,7 @@ int run(int argc, char* argv[]) ...@@ -163,7 +163,7 @@ int run(int argc, char* argv[])
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * DeviceMem lse_device_buf(sizeof(LSEDataType) *
lse_gs_ms_device_result.mDesc.GetElementSpaceSize()); lse_gs_ms_device_result.mDesc.GetElementSpaceSize());
...@@ -181,40 +181,40 @@ int run(int argc, char* argv[]) ...@@ -181,40 +181,40 @@ int run(int argc, char* argv[])
// do GEMM // do GEMM
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument( auto argument =
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), //
{}, // std::array<void*, 1> p_acc1_biases; nullptr,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides, b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths, b1_gs_os_ns_lengths,
b1_gs_os_ns_strides, b1_gs_os_ns_strides,
c_gs_ms_os_lengths, c_gs_ms_os_lengths,
c_gs_ms_os_strides, c_gs_ms_os_strides,
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::vector<ck::index_t>
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // std::vector<ck::index_t>
a_element_op, a_element_op,
b0_element_op, b0_element_op,
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number of {seed, offset}); // dropout random seed and offset, offset should be at
// elements on a thread // least the number of elements on a thread
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -227,15 +227,16 @@ int run(int argc, char* argv[]) ...@@ -227,15 +227,16 @@ int run(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + std::size_t num_bytes =
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
sizeof(DDataType) * M * N * Acc0BiasDataType::Size()) * sizeof(CDataType) * M * O +
BatchCount; sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_bytes / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl; << gemm.GetTypeString() << std::endl;
...@@ -243,41 +244,38 @@ int run(int argc, char* argv[]) ...@@ -243,41 +244,38 @@ int run(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
// run for storing z tensor // run for storing z tensor
argument = gemm.MakeArgument( argument =
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{ static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()),
d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; nullptr,
{}, // std::array<void*, 1> p_acc1_biases; a_gs_ms_ks_lengths,
a_gs_ms_ks_lengths, a_gs_ms_ks_strides,
a_gs_ms_ks_strides, b0_gs_ns_ks_lengths,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides,
b0_gs_ns_ks_strides, b1_gs_os_ns_lengths,
b1_gs_os_ns_lengths, b1_gs_os_ns_strides,
b1_gs_os_ns_strides, c_gs_ms_os_lengths,
c_gs_ms_os_lengths, c_gs_ms_os_strides,
c_gs_ms_os_strides, z_gs_ms_ns_lengths,
z_gs_ms_ns_lengths, z_gs_ms_ns_strides,
z_gs_ms_ns_strides, lse_gs_ms_lengths,
lse_gs_ms_lengths, d_gs_ms_ns_lengths,
std::array<std::vector<ck::index_t>, 1>{ d_gs_ms_ns_strides,
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths {},
std::array<std::vector<ck::index_t>, 1>{ {},
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides a_element_op,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, b0_element_op,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, acc0_element_op,
a_element_op, b1_element_op,
b0_element_op, c_element_op,
acc0_element_op, p_drop, // dropout ratio
b1_element_op, {seed, offset}); // dropout random seed and offset, offset should be
c_element_op, // at least the number of elements on a thread
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number
// of elements on a thread
c_device_buf.SetZero(); c_device_buf.SetZero();
lse_device_buf.SetZero(); lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
...@@ -294,7 +292,7 @@ int run(int argc, char* argv[]) ...@@ -294,7 +292,7 @@ int run(int argc, char* argv[])
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
Tensor<LSEDataType> lse_g_m_host_result( Tensor<LSEDataType> lse_g_m_host_result(
{BatchCount, M}); // scratch object after max + ln(sum) {BatchCount, M}); // scratch object after max + ln(sum)
Tensor<DDataType> d_g_m_n({G0 * G1, M, N}); Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N}); Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
...@@ -324,12 +322,12 @@ int run(int argc, char* argv[]) ...@@ -324,12 +322,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias // bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); }); acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<AccDataType>::Infinity();
}); });
// softmax // softmax
......
...@@ -57,7 +57,7 @@ int run(int argc, char* argv[]) ...@@ -57,7 +57,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0; std::vector<const void*> p_b0;
std::vector<const void*> p_b1; std::vector<const void*> p_b1;
std::vector<void*> p_c; std::vector<void*> p_c;
std::vector<std::vector<const void*>> p_d; std::vector<const void*> p_d;
std::vector<void*> p_z; // for result verification std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse; std::vector<void*> p_lse;
...@@ -67,7 +67,7 @@ int run(int argc, char* argv[]) ...@@ -67,7 +67,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<B0DataType>> b0_tensors; std::vector<Tensor<B0DataType>> b0_tensors;
std::vector<Tensor<B1DataType>> b1_tensors; std::vector<Tensor<B1DataType>> b1_tensors;
std::vector<Tensor<CDataType>> c_tensors; std::vector<Tensor<CDataType>> c_tensors;
std::vector<Tensor<DDataType>> d_tensors; std::vector<Tensor<Acc0BiasDataType>> d_tensors;
std::vector<Tensor<ZDataType>> z_tensors; std::vector<Tensor<ZDataType>> z_tensors;
std::vector<Tensor<LSEDataType>> lse_tensors; std::vector<Tensor<LSEDataType>> lse_tensors;
...@@ -147,10 +147,8 @@ int run(int argc, char* argv[]) ...@@ -147,10 +147,8 @@ int run(int argc, char* argv[])
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
lse_gs_ms_strides, lse_gs_ms_strides,
std::vector<std::vector<ck::index_t>>{ d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
std::vector<std::vector<ck::index_t>>{
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths {}, // acc1_biases_gs_ms_os_lengths
{}}); // acc1_biases_gs_ms_os_strides {}}); // acc1_biases_gs_ms_os_strides
...@@ -159,7 +157,7 @@ int run(int argc, char* argv[]) ...@@ -159,7 +157,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
...@@ -167,7 +165,7 @@ int run(int argc, char* argv[]) ...@@ -167,7 +165,7 @@ int run(int argc, char* argv[])
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * (Acc0BiasDataType::Size() ? 0 : 1)) * sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
Batch; Batch;
if(i < 4) if(i < 4)
...@@ -191,25 +189,25 @@ int run(int argc, char* argv[]) ...@@ -191,25 +189,25 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
break; break;
case 2: case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break; break;
case 3: case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
} }
a_tensors.push_back(a_gs_ms_ks); a_tensors.push_back(a_gs_ms_ks);
...@@ -229,7 +227,7 @@ int run(int argc, char* argv[]) ...@@ -229,7 +227,7 @@ int run(int argc, char* argv[])
c_tensors_device.emplace_back(std::make_unique<DeviceMem>( c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize())); sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
d_tensors_device.emplace_back(std::make_unique<DeviceMem>( d_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize())); sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()));
z_tensors_device.emplace_back(std::make_unique<DeviceMem>( z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize())); sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
lse_tensors_device.emplace_back(std::make_unique<DeviceMem>( lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
...@@ -244,9 +242,7 @@ int run(int argc, char* argv[]) ...@@ -244,9 +242,7 @@ int run(int argc, char* argv[])
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer()); p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_d.push_back({d_tensors_device[i]->GetDeviceBuffer()}); p_d.push_back(d_tensors_device[i]->GetDeviceBuffer());
// std::cout << "from host group id: " << i << " d address: " <<
// d_tensors_device[i]->GetDeviceBuffer() << std::endl;
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer()); p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr); p_z_nullptr.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer()); p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
...@@ -363,7 +359,7 @@ int run(int argc, char* argv[]) ...@@ -363,7 +359,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N}); Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O}); Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
Tensor<AccDataType> d_g_m_n({G0 * G1, M, N}); Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
...@@ -400,12 +396,12 @@ int run(int argc, char* argv[]) ...@@ -400,12 +396,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias // bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); }); acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<AccDataType>::Infinity();
}); });
// softmax // softmax
......
...@@ -138,12 +138,12 @@ struct BlockwiseDropout ...@@ -138,12 +138,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 8;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
} }
block_sync_lds(); block_sync_lds();
...@@ -179,12 +179,12 @@ struct BlockwiseDropout ...@@ -179,12 +179,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 8;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
} }
block_sync_lds(); block_sync_lds();
...@@ -218,21 +218,19 @@ struct BlockwiseDropout ...@@ -218,21 +218,19 @@ struct BlockwiseDropout
} }
// get raw z matrix with random number for shuffle // get raw z matrix with random number for shuffle
template <typename ZThreadBuffer, template <typename ZThreadBuffer, typename Step, typename Offset>
typename Step,
typename Offset> // N3*N4=8
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph, __host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id, index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf) ZThreadBuffer& z_thread_buf)
{ {
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value; constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 8;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{}); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{});
} }
static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; }); static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
......
...@@ -87,9 +87,6 @@ template <index_t NumDimG, ...@@ -87,9 +87,6 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionForward : public BaseOperator struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
{ {
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, const void* p_a,
const void* p_b0, const void* p_b0,
...@@ -97,8 +94,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator ...@@ -97,8 +94,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
void* p_c, void* p_c,
void* p_z, void* p_z,
void* p_lse, void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -110,12 +107,10 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator ...@@ -110,12 +107,10 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths
const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias> const std::vector<index_t>& acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<index_t>& acc1_bias_gs_ms_gemm1ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op, Acc0ElementwiseOperation acc0_element_op,
......
...@@ -111,11 +111,11 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator ...@@ -111,11 +111,11 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths; std::vector<index_t> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides; std::vector<index_t> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths; std::vector<index_t> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides; std::vector<index_t> acc1_biases_gs_ms_os_strides;
}; };
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
...@@ -125,9 +125,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator ...@@ -125,9 +125,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_bias_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op, Acc0ElementwiseOperation acc0_element_op,
......
...@@ -289,12 +289,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -289,12 +289,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0 #if 0
// TODO ANT: use alias // TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM; static constexpr index_t NumDimGemm0M = NumDimM;
...@@ -535,39 +529,36 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -535,39 +529,36 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
// FIXME: constness // FIXME: constness
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const ADataType* p_a_grid,
const ADataType* p_a_grid, const BDataType* p_b_grid,
const BDataType* p_b_grid, const B1DataType* p_b1_grid,
const B1DataType* p_b1_grid, CDataType* p_c_grid,
CDataType* p_c_grid, ZDataType* p_z_grid,
ZDataType* p_z_grid, LSEDataType* p_lse_grid,
LSEDataType* p_lse_grid, const void* p_acc0_bias,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc1_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<ck::index_t> acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t> acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t> acc1_bias_gs_ms_gemm1ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t> acc1_bias_gs_ms_gemm1ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths AElementwiseOperation a_element_op,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> BElementwiseOperation b_element_op,
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AccElementwiseOperation acc_element_op,
AElementwiseOperation a_element_op, B1ElementwiseOperation b1_element_op,
BElementwiseOperation b_element_op, CElementwiseOperation c_element_op,
AccElementwiseOperation acc_element_op, float p_dropout,
B1ElementwiseOperation b1_element_op, std::tuple<unsigned long long, unsigned long long> seeds)
CElementwiseOperation c_element_op,
float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
...@@ -624,12 +615,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -624,12 +615,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())} type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{ {
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
ignore = p_acc0_biases; ignore = p_acc0_bias;
ignore = p_acc1_biases; ignore = p_acc1_bias;
ignore = acc0_biases_gs_ms_ns_lengths; ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides; ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths; ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides; ignore = acc1_bias_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
...@@ -984,39 +975,37 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -984,39 +975,37 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
const ADataType* p_a, MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
const B1DataType* p_b1, const B1DataType* p_b1,
CDataType* p_c, CDataType* p_c,
ZDataType* p_z, ZDataType* p_z,
LSEDataType* p_lse, LSEDataType* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> AElementwiseOperation a_element_op,
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides BElementwiseOperation b_element_op,
AElementwiseOperation a_element_op, AccElementwiseOperation acc_element_op,
BElementwiseOperation b_element_op, B1ElementwiseOperation b1_element_op,
AccElementwiseOperation acc_element_op, CElementwiseOperation c_element_op,
B1ElementwiseOperation b1_element_op, float p_dropout,
CElementwiseOperation c_element_op, std::tuple<unsigned long long, unsigned long long> seeds)
float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -1024,8 +1013,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1024,8 +1013,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
p_c, p_c,
p_z, p_z,
p_lse, p_lse,
p_acc0_biases, p_acc0_bias,
p_acc1_biases, p_acc1_bias,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1037,10 +1026,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1037,10 +1026,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -1061,8 +1050,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1061,8 +1050,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
void* p_c, void* p_c,
void* p_z, void* p_z,
void* p_lse, void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1074,12 +1063,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1074,12 +1063,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1094,8 +1081,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1094,8 +1081,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
static_cast<ZDataType*>(p_z), static_cast<ZDataType*>(p_z),
static_cast<LSEDataType*>(p_lse), static_cast<LSEDataType*>(p_lse),
p_acc0_biases, // cast in struct Argument p_acc0_bias, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_bias, // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1107,10 +1094,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1107,10 +1094,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides, acc1_bias_gs_ms_gemm1ns_strides,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
......
...@@ -279,12 +279,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -279,12 +279,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0 #if 0
// TODO ANT: use alias // TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM; static constexpr index_t NumDimGemm0M = NumDimM;
...@@ -603,8 +597,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -603,8 +597,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_bias_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -619,6 +613,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -619,6 +613,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
ignore = p_acc0_bias_vec;
ignore = p_acc1_bias_vec;
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size(); group_count_ = problem_desc_vec.size();
...@@ -628,11 +625,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -628,11 +625,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size"); throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
} }
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0; grid_size_ = 0;
for(std::size_t i = 0; i < group_count_; i++) for(std::size_t i = 0; i < group_count_; i++)
...@@ -710,18 +702,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -710,18 +702,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_b1_grid, p_b1_grid,
...@@ -1055,8 +1035,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1055,8 +1035,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_bias_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1072,8 +1052,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1072,8 +1052,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
p_c_vec, p_c_vec,
p_z_vec, p_z_vec,
p_lse_vec, p_lse_vec,
p_acc0_biases_vec, p_acc0_bias_vec,
p_acc1_biases_vec, p_acc1_bias_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1094,9 +1074,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1094,9 +1074,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_bias_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1111,8 +1091,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1111,8 +1091,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
p_c_vec, p_c_vec,
p_z_vec, p_z_vec,
p_lse_vec, p_lse_vec,
p_acc0_biases_vec, p_acc0_bias_vec,
p_acc1_biases_vec, p_acc1_bias_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment