Commit 24c7c49f authored by Adam Osewski's avatar Adam Osewski
Browse files

Introduce Device GroupedGemmSplitK

parent 53d3ee2b
#pragma once
#include <iostream>
#include <vector>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -15,7 +15,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
......@@ -118,7 +118,7 @@ template <typename ALayout,
enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> &&
is_same_v<DsDataType, ck::Tuple<>>,
bool> = false>
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemm<ALayout,
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
......@@ -184,7 +184,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemm<ALayout,
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument;
struct GemmTransKernelArg
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment