Commit b158c537 authored by letaoqin

add device parameters

parent 81f4481c
@@ -229,7 +229,8 @@ int run(int argc, char* argv[])
     std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
     std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
+                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
+                             sizeof(DDataType) * M * N * Acc0BiasDataType::Size()) *
                             BatchCount;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
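Note on the new traffic term: each enabled acc0 bias adds one M x N read of DDataType per batch, which `Acc0BiasDataType::Size()` scales in the hunk above. A standalone sanity-check sketch (all sizes and the 2-byte element width are assumptions for illustration):

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    // Assumed example: 2-byte DDataType (e.g. fp16), M = N = 1024, one bias tensor.
    const std::size_t M = 1024, N = 1024, num_bias = 1, sizeof_d = 2;
    // Mirrors the new term: sizeof(DDataType) * M * N * Acc0BiasDataType::Size()
    const std::size_t bias_bytes = sizeof_d * M * N * num_bias;
    std::printf("extra bias traffic per batch: %zu bytes\n", bias_bytes); // 2097152
    return 0;
}
```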
@@ -57,6 +57,7 @@ int run(int argc, char* argv[])
     std::vector<const void*> p_b0;
     std::vector<const void*> p_b1;
     std::vector<void*> p_c;
+    std::vector<const void*> p_d;
     std::vector<void*> p_z;         // for result verification
     std::vector<void*> p_z_nullptr; // for time test
     std::vector<void*> p_lse;
@@ -66,6 +67,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<B0DataType>> b0_tensors;
     std::vector<Tensor<B1DataType>> b1_tensors;
     std::vector<Tensor<CDataType>> c_tensors;
+    std::vector<Tensor<DDataType>> d_tensors;
     std::vector<Tensor<ZDataType>> z_tensors;
     std::vector<Tensor<LSEDataType>> lse_tensors;
@@ -74,6 +76,7 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> b0_tensors_device;
     std::vector<DeviceMemPtr> b1_tensors_device;
     std::vector<DeviceMemPtr> c_tensors_device;
+    std::vector<DeviceMemPtr> d_tensors_device;
     std::vector<DeviceMemPtr> z_tensors_device;
     std::vector<DeviceMemPtr> lse_tensors_device;
@@ -116,6 +119,12 @@ int run(int argc, char* argv[])
                 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+
+        std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
+        std::vector<ck::index_t> d_gs_ms_ns_strides =
+            input_permute
+                ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
+                : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]

         std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
         std::vector<ck::index_t> z_gs_ms_ns_strides =
             input_permute
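The D strides mirror the Z tensor: the logical index order stays [G0, G1, M, N] while `input_permute` selects a [G0, M, G1, N] memory order. A minimal sketch of the resulting offset arithmetic (the `offset` helper and the concrete sizes are hypothetical):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: flat offset of logical index (g0, g1, m, n) given strides.
std::size_t offset(const std::vector<int>& s, int g0, int g1, int m, int n)
{
    return std::size_t(g0) * s[0] + std::size_t(g1) * s[1] + std::size_t(m) * s[2] +
           std::size_t(n) * s[3];
}

int main()
{
    const int G1 = 2, M = 4, N = 8;
    // input_permute == true: memory order [G0, M, G1, N]
    const std::vector<int> permuted{M * G1 * N, N, G1 * N, 1};
    assert(offset(permuted, 0, 0, 0, 1) == 1);                   // n is contiguous
    assert(offset(permuted, 0, 0, 1, 0) == std::size_t(G1) * N); // m steps over all G1 heads
    return 0;
}
```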
@@ -138,8 +147,10 @@ int run(int argc, char* argv[])
             z_gs_ms_ns_strides,
             lse_gs_ms_lengths,
             lse_gs_ms_strides,
-            {}, // acc0_biases_gs_ms_ns_lengths
-            {}, // acc0_biases_gs_ms_ns_strides
+            std::vector<std::vector<ck::index_t>>{
+                d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
+            std::vector<std::vector<ck::index_t>>{
+                d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
             {},   // acc1_biases_gs_ms_os_lengths
             {}}); // acc1_biases_gs_ms_os_strides
@@ -148,13 +159,15 @@ int run(int argc, char* argv[])
         Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
         Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
         Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+        Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
         Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
         Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);

         int Batch = G0 * G1;
         flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
         num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
+                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
+                     sizeof(DDataType) * M * N * (Acc0BiasDataType::Size() ? 1 : 0)) *
                     Batch;

         if(i < 4)
@@ -177,27 +190,32 @@ int run(int argc, char* argv[])
             a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
             b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
             b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
+            d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2});
             break;
         case 2:
             a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
             b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
             b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
+            d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
             break;
         case 3:
             a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
             b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
             b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
+            d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
             break;
         default:
             a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
             b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
             b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
+            d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
         }

         a_tensors.push_back(a_gs_ms_ks);
         b0_tensors.push_back(b0_gs_ns_ks);
         b1_tensors.push_back(b1_gs_os_ns);
         c_tensors.push_back(c_gs_ms_os_device_result);
+        d_tensors.push_back(d_gs_ms_ns);
         z_tensors.push_back(z_gs_ms_ns);
         lse_tensors.push_back(lse_gs_ms_device_result);
@@ -209,6 +227,8 @@ int run(int argc, char* argv[])
             sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
         c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
+        d_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+            sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()));
         z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
         lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
@@ -217,11 +237,13 @@ int run(int argc, char* argv[])
         a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data());
         b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
         b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
+        d_tensors_device[i]->ToDevice(d_gs_ms_ns.mData.data());

         p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
         p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
         p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
         p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
+        p_d.push_back(d_tensors_device[i]->GetDeviceBuffer());
         p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
         p_z_nullptr.push_back(nullptr);
         p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
@@ -244,7 +266,7 @@ int run(int argc, char* argv[])
             p_c,
             p_z_nullptr,
             p_lse,
-            {}, // p_acc0_biases
+            std::vector<std::vector<const void*>>{p_d}, // p_acc0_biases
             {}, // p_acc1_biases
             problem_descs,
             a_element_op,
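Shape-only sketch of the new bias argument (placeholder pointers; the nesting matches the call above, where the old code passed an empty list):

```cpp
#include <vector>

int main()
{
    const int group_count = 4; // assumed number of problem descriptors
    // One D device pointer per group, as gathered into p_d above...
    std::vector<const void*> p_d(group_count, nullptr); // placeholders, not real buffers
    // ...then wrapped in the nested vector passed as p_acc0_biases.
    auto p_acc0_biases = std::vector<std::vector<const void*>>{p_d};
    (void)p_acc0_biases;
    return 0;
}
```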
@@ -285,12 +285,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                   "Number of dimensions must be greater than 0");

-    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
-    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+    static constexpr index_t NumD0Tensor = Acc0BiasDataType::Size();
+    static constexpr index_t NumD1Tensor = Acc1BiasDataType::Size();

     // TODO ANT: implement bias combination
-    static_assert(NumAcc0Bias <= 1, "Bias0 addition only supports one bias");
-    static_assert(NumAcc1Bias == 0, "Bias addition is unimplemented");
+    static_assert(NumD0Tensor <= 1, "Bias0 addition only supports one bias");
+    static_assert(NumD1Tensor == 0, "Bias addition is unimplemented");
+
+    static_assert(NumD0Tensor == 0
+                      ? true
+                      : std::is_same_v<ADataType, ck::tuple_element_t<0, Acc0BiasDataType>>);
+    using DDataType = ADataType;

 #if 0
     // TODO ANT: use alias
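Given these constraints, a single-bias instantiation is expected to look roughly like the fragment below (abridged; assumes CK's `ck::Tuple` and an `ADataType` alias already in scope):

```cpp
// The bias element type must equal ADataType per the static_assert above;
// an empty tuple keeps the bias path disabled.
using Acc0BiasDataType = ck::Tuple<ADataType>; // NumD0Tensor == 1
using Acc1BiasDataType = ck::Tuple<>;          // NumD1Tensor == 0 (required)
```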
@@ -398,12 +402,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
     using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
     using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
     using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
+    using DGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
     using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
     using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
     using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
     using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
     using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+    using DGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
     using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));

     constexpr static auto make_MaskOutPredicate()
@@ -429,12 +435,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                        const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                        const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                        const CGridDesc_G_M_N& c_grid_desc_g_m_n,
+                       const DGridDesc_G_M_N& d_grid_desc_g_m_n,
                        const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
                        index_t BatchStrideLSE)
             : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
               b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
               b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
               c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
+              d_grid_desc_g_m_n_(d_grid_desc_g_m_n),
               z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
               BatchStrideLSE_(BatchStrideLSE)
         {
@@ -460,6 +468,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
             return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

+        __host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx) const
+        {
+            return d_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
+        }
+
         __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
         {
             return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
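A hedged sketch of how a kernel would presumably consume the new helper, mirroring the existing GetCBasePtr/GetZBasePtr pattern (`compute_base_ptr_of_batch` and `g_idx` stand in for the surrounding kernel plumbing and are assumptions here):

```cpp
// Per-batch D pointer: the G-strided descriptor maps a batch index to a flat
// element offset, exactly as the C and Z base-pointer helpers do.
const DDataType* p_d_batch = p_d_grid_ + compute_base_ptr_of_batch.GetDBasePtr(g_idx);
```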
@@ -475,6 +488,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
+        DGridDesc_G_M_N d_grid_desc_g_m_n_;
         ZGridDesc_G_M_N z_grid_desc_g_m_n_;
         index_t BatchStrideLSE_;
@@ -558,6 +572,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
         const BDataType* p_b_grid_;
         const B1DataType* p_b1_grid_;
         CDataType* p_c_grid_;
+        const DDataType* p_d_grid_;
         ZDataType* p_z_grid_;
         LSEDataType* p_lse_grid_;
@@ -567,6 +582,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
+        DGridDesc_M_N d_grid_desc_m_n_;
+        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
         typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
             z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
         ZGridDesc_M_N z_grid_desc_m_n_;
@@ -653,6 +671,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                 const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
                 const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
                 const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
+                const auto p_d_grid = NumD0Tensor == 0
+                                          ? nullptr
+                                          : static_cast<const DDataType*>(p_acc0_biases_vec[i][0]);
                 const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
                 const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
@@ -671,6 +692,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                     problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
                 const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
                     problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
+                const auto d_grid_desc_m_n =
+                    NumD0Tensor == 0
+                        ? DGridDesc_M_N{}
+                        : MakeZGridDescriptor_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths[0],
+                                                  problem_desc.acc0_biases_gs_ms_ns_strides[0]);
+                const auto d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+                    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
+                        d_grid_desc_m_n);
                 const auto z_grid_desc_m_n = MakeZGridDescriptor_M_N(
                     problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
                 const auto lse_grid_desc_m =
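The `NumD0Tensor == 0 ? DGridDesc_M_N{} : ...` guard substitutes a default-constructed descriptor of the same type when no bias is present, so the downstream argument plumbing stays uniform. A standalone illustration of the pattern (stand-in types):

```cpp
struct Desc { int m = 0, n = 0; };                  // stand-in for DGridDesc_M_N
Desc make_desc(int m, int n) { return Desc{m, n}; } // stand-in descriptor builder

int main()
{
    constexpr int NumD0Tensor = 0; // assumed: bias path disabled
    // Both ternary arms yield the same type, so the result stays well-formed
    // whether or not a real descriptor is built.
    auto d = NumD0Tensor == 0 ? Desc{} : make_desc(4, 8);
    (void)d;
    return 0;
}
```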
@@ -684,6 +713,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                     problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
                 const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
                     problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
+                const auto d_grid_desc_g_m_n =
+                    NumD0Tensor == 0 ? DGridDesc_G_M_N{}
+                                     : Transform::MakeCGridDescriptor_G_M_N(
+                                           problem_desc.acc0_biases_gs_ms_ns_lengths[0],
+                                           problem_desc.acc0_biases_gs_ms_ns_strides[0]);
                 const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
                     problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
@@ -713,6 +747,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                     b_grid_desc_g_n_k,
                     b1_grid_desc_g_n_k,
                     c_grid_desc_g_m_n,
+                    d_grid_desc_g_m_n,
                     z_grid_desc_g_m_n,
                     type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
@@ -722,12 +757,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                 grid_size_ += grid_size_grp;

-                // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
+                // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumD0Tensor and
                 // so on
-                if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
-                     problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
-                     problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
-                     problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
+                if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumD0Tensor &&
+                     problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumD0Tensor &&
+                     problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumD1Tensor &&
+                     problem_desc.acc1_biases_gs_ms_os_strides.size() == NumD1Tensor))
                 {
                     throw std::runtime_error(
                         "wrong! number of biases in function argument does not "
@@ -743,12 +778,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                     p_b_grid,
                     p_b1_grid,
                     p_c_grid,
+                    p_d_grid,
                     p_z_grid,
                     p_lse_grid,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
                     b1_grid_desc_bk0_n_bk1,
                     c_grid_desc_mblock_mperblock_nblock_nperblock,
+                    d_grid_desc_m_n,
+                    d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                     z_grid_desc_m_n,
                     lse_grid_desc_m,