"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "aeb2cab1cd593231ef38f80827cb502973825354"
Commit 100c4bb3 authored by Jing Zhang's avatar Jing Zhang
Browse files

add empty Ds

parent c58a3877
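The commit threads an (optionally empty) tuple of extra D tensors through the grouped-GEMM device op: the example, the DeviceGroupedGemm interface, the XDL implementation, the prebuilt f16 instances, and the profiler all gain a DsDataType / p_Ds parameter in front of the E (output) arguments. As orientation before the diff, here is a minimal sketch of the call shape after this change, using the type aliases from the example below. It is not a complete program; the CK #include lines are omitted because header paths differ between versions, and GemmDesc's namespace is assumed from the device header in this diff.

// Sketch only: DeviceGemmInstance, AElementOp, BElementOp, CDEElementOp come from the example diff below.
using DsDataType = ck::Tuple<>;                        // "empty Ds": no extra D tensors
std::vector<const void*> p_a, p_b;                     // one A/B device pointer per group
std::vector<void*> p_c;                                // one E (output) device pointer per group
std::vector<std::vector<const void*>> p_Ds = {};       // new argument; stays empty when DsDataType is ck::Tuple<>
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs; // each descriptor now ends with stride_Ds_

auto gemm     = DeviceGemmInstance{};
auto invoker  = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
    p_a, p_b, p_Ds, p_c, gemm_descs, AElementOp{}, BElementOp{}, CDEElementOp{});
// invoker.Run(...) then follows as in the original example (not shown in this diff).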
@@ -29,34 +29,39 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-using ADataType   = ck::half_t;
-using BDataType   = ck::half_t;
-using CDataType   = ck::half_t;
-using AccDataType = float;
+using ADataType        = F16;
+using BDataType        = F16;
+using AccDataType      = F32;
+using CShuffleDataType = F16;
+using DsDataType       = ck::Tuple<>;
+using EDataType        = F16;

-using ALayout = ck::tensor_layout::gemm::RowMajor;
-using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
+using ALayout = Row;
+using BLayout = Col;
+using ELayout = Row;

-using AElementOp = ck::tensor_operation::element_wise::PassThrough;
-using BElementOp = ck::tensor_operation::element_wise::PassThrough;
-using CElementOp = ck::tensor_operation::element_wise::PassThrough;
+using AElementOp   = PassThrough;
+using BElementOp   = PassThrough;
+using CDEElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-// static constexpr auto GemmMNPadding =
-//     ck::tensor_operation::device::GemmSpecialization::MNPadding;

-// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
-//######| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| EData| A| B| CE| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//######| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-//######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+// clang-format off
+//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on

-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                        BDataType,
+                                                                        EDataType,
+                                                                        AccDataType,
+                                                                        AElementOp,
+                                                                        BElementOp,
+                                                                        CDEElementOp>;

int main(int argc, char* argv[])
{
@@ -93,7 +98,11 @@ int main(int argc, char* argv[])
        int N = 128 + 128 * i;
        int K = 64 + 64 * i;

-        gemm_descs.push_back({M, N, K, K, K, N});
+        int stride_A = K;
+        int stride_B = K;
+        int stride_C = N;
+
+        gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, {}});
    }

    auto f_host_tensor_descriptor =
@@ -111,10 +120,9 @@ int main(int argc, char* argv[])
    };

    std::vector<Tensor<ADataType>> a_tensors;
-    ;
    std::vector<Tensor<BDataType>> b_tensors;
-    std::vector<Tensor<CDataType>> c_host_tensors;
-    std::vector<Tensor<CDataType>> c_device_tensors;
+    std::vector<Tensor<EDataType>> c_host_tensors;
+    std::vector<Tensor<EDataType>> c_device_tensors;

    a_tensors.reserve(group_count);
    b_tensors.reserve(group_count);
@@ -137,10 +145,10 @@ int main(int argc, char* argv[])
            gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
        b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
            gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
-        c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
-        c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
+        c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
+            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
+        c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
+            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));

        std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
                  << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
@@ -149,7 +157,7 @@ int main(int argc, char* argv[])
        flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
        num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
                     sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
-                     sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
+                     sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();

        switch(init_method)
        {
@@ -175,7 +183,7 @@ int main(int argc, char* argv[])
        b_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpace()));
        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSpace()));
+            sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSpace()));

        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
@@ -187,14 +195,16 @@ int main(int argc, char* argv[])
    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
-    auto c_element_op = CElementOp{};
+    auto c_element_op = CDEElementOp{};

    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

+    std::vector<std::vector<const void*>> p_Ds = {};
+
    // do GEMM
-    auto argument =
-        gemm.MakeArgument(p_a, p_b, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
+    auto argument = gemm.MakeArgument(
+        p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);

    DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
......
@@ -12,14 +12,17 @@ struct GemmDesc
{
    ck::index_t M_, N_, K_;
    ck::index_t stride_A_, stride_B_, stride_C_;
+
+    std::vector<ck::index_t> stride_Ds_;
};

template <typename ALayout,
          typename BLayout,
-          typename CLayout,
+          typename DELayout,
          typename ADataType,
          typename BDataType,
-          typename CDataType,
+          typename DsDataType,
+          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
@@ -28,7 +31,8 @@ struct DeviceGroupedGemm : public BaseOperator
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*>& p_a,
                        std::vector<const void*>& p_b,
-                        std::vector<void*>& p_c,
+                        std::vector<std::vector<const void*>>& p_ds,
+                        std::vector<void*>& p_e,
                        std::vector<GemmDesc>& gemm_desc,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
@@ -39,19 +43,21 @@ struct DeviceGroupedGemm : public BaseOperator
template <typename ALayout,
          typename BLayout,
-          typename CLayout,
+          typename DELayout,
          typename ADataType,
          typename BDataType,
-          typename CDataType,
+          typename DsDataType,
+          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceGroupedGemmPtr = std::unique_ptr<DeviceGroupedGemm<ALayout,
                                                               BLayout,
-                                                               CLayout,
+                                                               DELayout,
                                                               ADataType,
                                                               BDataType,
-                                                               CDataType,
+                                                               DsDataType,
+                                                               EDataType,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
                                                               CElementwiseOperation>>;
......
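Because GemmDesc now carries a stride_Ds_ vector, every aggregate initialization of it gains one trailing field, which is why the call sites in the example above and in the profiler below append an extra {}. A small sketch of the two shapes, assuming GemmDesc lives in the surrounding ck::tensor_operation::device namespace; the one-D case and its values are hypothetical (a row-major D with leading dimension N):

// Hedged sketch: GemmDesc as declared in the header above; values are arbitrary.
const ck::index_t M = 256, N = 128, K = 64;
ck::tensor_operation::device::GemmDesc desc_no_ds{M, N, K, /*stride_A*/ K, /*stride_B*/ K, /*stride_C*/ N, {}};
ck::tensor_operation::device::GemmDesc desc_one_d{M, N, K, K, K, N, /*stride_Ds*/ {N}};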
@@ -25,7 +25,7 @@ template <typename GridwiseGemm,
          typename GemmDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
-          typename CElementwiseOperation,
+          typename CDEElementwiseOperation,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -35,7 +35,7 @@ __global__ void
            const index_t group_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
-            const CElementwiseOperation c_element_op)
+            const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -67,7 +67,7 @@ __global__ void
            gemm_desc_ptr[group_id].a_ptr_,
            gemm_desc_ptr[group_id].b_ptr_,
            ck::Tuple<>{},
-            gemm_desc_ptr[group_id].c_ptr_,
+            gemm_desc_ptr[group_id].e_ptr_,
            p_shared,
            a_element_op,
            b_element_op,
@@ -90,15 +90,16 @@ __global__ void
template <typename ALayout,
          typename BLayout,
-          typename CLayout,
+          typename DELayout,
          typename ADataType,
          typename BDataType,
-          typename AccDataType,
+          typename GemmAccDataType,
          typename CShuffleDataType,
+          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
-          typename CElementwiseOperation,
+          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
@@ -132,14 +133,17 @@ template <typename ALayout,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
                                                       BLayout,
-                                                       CLayout,
+                                                       DELayout,
                                                       ADataType,
                                                       BDataType,
+                                                       DsDataType,
                                                       EDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
-                                                       CElementwiseOperation>
+                                                       CDEElementwiseOperation>
{
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
@@ -353,12 +357,12 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideE, I1));
            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideE));
@@ -415,13 +419,13 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
-        AccDataType,
+        GemmAccDataType,
        CShuffleDataType,
-        ck::Tuple<>,
+        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
-        CElementwiseOperation,
+        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
@@ -510,11 +514,17 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;
+        StaticallyIndexedArray<
+            typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+            NumDTensor>
+            ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
+
        GroupedGemmBlock2ETileMap block_2_ctile_map_;

        const ADataType* a_ptr_;
        const BDataType* b_ptr_;
-        EDataType* c_ptr_;
+        typename GridwiseGemm::DsGridPointer ds_ptr_;
+        EDataType* e_ptr_;

        ck::index_t BlockStart_, BlockEnd_;
    };
@@ -524,11 +534,12 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
    {
        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
+                 std::vector<std::vector<const void*>>& p_Ds,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
-                 CElementwiseOperation c_element_op)
+                 CDEElementwiseOperation c_element_op)
            : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
        {
            grid_size_ = 0;
@@ -582,15 +593,40 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
                auto e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);

+                StaticallyIndexedArray<
+                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    NumDTensor>
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of
+                                                                     // different
+
+                typename GridwiseGemm::DsGridPointer p_ds_grid_;
+
+                if constexpr(NumDTensor > 0)
+                {
+                    static_for<0, NumDTensor, 1>{}([&](auto j) {
+                        using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
+
+                        p_ds_grid_(i) = static_cast<const DDataType*>(p_Ds[i][j]);
+
+                        const auto d_grid_desc_m_n = GridwiseGemm::MakeEGridDescriptor_M_N(
+                            M, N, gemm_descs[i].stride_Ds_[j]);
+
+                        ds_grid_desc_mblock_mperblock_nblock_nperblock_(j) =
+                            GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                                d_grid_desc_m_n);
+                    });
+                }
+
                gemm_desc_kernel_arg_.push_back(
                    GemmBiasTransKernelArg{a_grid_desc_k0_m_k1_,
                                           b_grid_desc_k0_n_k1_,
                                           e_grid_desc_m_n_,
                                           e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                           block_2_ctile_map_,
                                           static_cast<const ADataType*>(p_As[i]),
                                           static_cast<const BDataType*>(p_Bs[i]),
+                                           p_ds_grid_,
                                           static_cast<EDataType*>(p_Es[i]),
                                           BlockStart,
                                           BlockEnd});
@@ -602,7 +638,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
        index_t group_count_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
-        CElementwiseOperation c_element_op_;
+        CDEElementwiseOperation c_element_op_;

        std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
@@ -666,7 +702,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
                    GemmBiasTransKernelArg,
                    AElementwiseOperation,
                    BElementwiseOperation,
-                    CElementwiseOperation,
+                    CDEElementwiseOperation,
                    has_main_k_block_loop_>;

                return launch_and_time_kernel(
@@ -724,13 +760,15 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
    static auto MakeArgument(std::vector<const void*>& p_As,
                             std::vector<const void*>& p_Bs,
+                             std::vector<std::vector<const void*>>& p_Ds,
                             std::vector<void*>& p_Es,
                             std::vector<GemmDesc> gemm_descs,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op)
+                             CDEElementwiseOperation c_element_op)
    {
-        return Argument{p_As, p_Bs, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
+        return Argument{
+            p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
@@ -738,14 +776,15 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_As,
                                                      std::vector<const void*>& p_Bs,
+                                                      std::vector<std::vector<const void*>>& p_Ds,
                                                      std::vector<void*>& p_Es,
                                                      std::vector<GemmDesc>& gemm_descs,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
-                                                      CElementwiseOperation c_element_op) override
+                                                      CDEElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(
-            p_As, p_Bs, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
+            p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
    }

    // polymorphic
......
@@ -16,39 +16,71 @@ namespace tensor_operation {
namespace device {
namespace instance {

+using DsType = Tuple<>;
+
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
-    std::vector<std::unique_ptr<
-        DeviceGroupedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
-        instances);
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  DsType,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);

void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<std::unique_ptr<
-        DeviceGroupedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
-        instances);
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Col,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  DsType,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);

void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
-    std::vector<std::unique_ptr<
-        DeviceGroupedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
-        instances);
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
+                                                  Row,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  DsType,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);

void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
-    std::vector<std::unique_ptr<
-        DeviceGroupedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
-        instances);
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
+                                                  Col,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  DsType,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
-          typename CDataType>
+          typename DsDataType,
+          typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<
    ALayout,
    BLayout,
    CLayout,
    ADataType,
    BDataType,
-    CDataType,
+    DsDataType,
+    EDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>>
@@ -58,7 +90,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
        CLayout,
        ADataType,
        BDataType,
-        CDataType,
+        DsDataType,
+        EDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>;
@@ -68,7 +101,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
-                     is_same_v<CDataType, half_t>)
+                     is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
......
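For callers that go through the instance factory rather than a hard-coded DeviceGemmInstance, the operation type now names an empty Ds tuple before the E data type. A hedged sketch of enumerating the f16 row-column instances with the new parameter list; the aliases (Row, Col, F16, PassThrough) are the ones used in the header above, while the GetInstances() entry point and the include paths are assumptions about the surrounding factory code, not something shown in this diff.

// Sketch only: GetInstances() is the assumed factory entry point.
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<Row, Col, Row,
                                                                 F16, F16,
                                                                 ck::Tuple<>, // empty Ds
                                                                 F16,
                                                                 PassThrough, PassThrough, PassThrough>;
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();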
@@ -23,6 +23,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

+using DsType = Tuple<>;
+
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -30,30 +32,37 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple<
    // clang-format off
-//##################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//##################| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-//##################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
-    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
+//##################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+    DeviceGroupedGemmXdl< Row, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
    // clang-format on
    >;

void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<std::unique_ptr<
-        DeviceGroupedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
-        instances)
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Col,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  DsType,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{});
......
@@ -24,7 +24,7 @@ namespace profiler {
template <typename ADataType,
          typename BDataType,
-          typename CDataType,
+          typename EDataType,
          typename AccDataType,
          typename ALayout,
          typename BLayout,
@@ -67,7 +67,7 @@ bool profile_grouped_gemm_impl(int do_verification,
    std::vector<Tensor<ADataType>> a_m_k;
    std::vector<Tensor<BDataType>> b_k_n;
-    std::vector<Tensor<CDataType>> c_m_n_device_results;
+    std::vector<Tensor<EDataType>> c_m_n_device_results;

    for(std::size_t i = 0; i < group_count; i++)
    {
@@ -77,7 +77,7 @@ bool profile_grouped_gemm_impl(int do_verification,
            Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
        c_m_n_device_results.push_back(
-            Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
+            Tensor<EDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));

        std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
                  << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
@@ -96,7 +96,7 @@ bool profile_grouped_gemm_impl(int do_verification,
            b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
        }

-        c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
+        c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<EDataType>{}, num_thread);
    }

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -138,13 +138,13 @@ bool profile_grouped_gemm_impl(int do_verification,
            std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
        c_device_buf.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace()));
+            sizeof(EDataType) * c_m_n_device_results[i].mDesc.GetElementSpace()));

        a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
        b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
        c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data());

-        gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]});
+        gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});

        p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
        p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
@@ -156,7 +156,8 @@ bool profile_grouped_gemm_impl(int do_verification,
                                                                  CLayout,
                                                                  ADataType,
                                                                  BDataType,
-                                                                  CDataType,
+                                                                  ck::Tuple<>,
+                                                                  EDataType,
                                                                  AElementOp,
                                                                  BElementOp,
                                                                  CElementOp>;
@@ -174,12 +175,15 @@ bool profile_grouped_gemm_impl(int do_verification,
    float best_tflops = 0;
    float best_gb_per_sec = 0;

+    auto p_ds = std::vector<std::vector<const void*>>{};
+
    // profile device GEMM instances
    for(auto& gemm_ptr : op_ptrs)
    {
        auto argument_ptr =
            gemm_ptr->MakeArgumentPointer(p_a,
                                          p_b,
+                                          p_ds,
                                          p_c,
                                          gemm_descs,
                                          ck::tensor_operation::element_wise::PassThrough{},
@@ -205,7 +209,7 @@ bool profile_grouped_gemm_impl(int do_verification,
            flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
            num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
-                         sizeof(CDataType) * Ms[i] * Ns[i];
+                         sizeof(EDataType) * Ms[i] * Ns[i];
        }

        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -229,13 +233,13 @@ bool profile_grouped_gemm_impl(int do_verification,
            c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());

-            Tensor<CDataType> c_m_n_host_result(
+            Tensor<EDataType> c_m_n_host_result(
                f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));

            using ReferenceGemmInstance =
                ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                          BDataType,
-                                                          CDataType,
+                                                          EDataType,
                                                          AccDataType,
                                                          AElementOp,
                                                          BElementOp,
......