Commit 24c7c49f authored by Adam Osewski
Browse files

Introduce Device GroupedGemmSplitK

parent 53d3ee2b
#pragma once
#include <iostream>
#include <vector>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

/// Abstract interface layered on top of DeviceGroupedGemm.
///
/// Every template parameter is forwarded unchanged, and in the same order, to
/// the DeviceGroupedGemm base. The only member this type adds is the pure
/// virtual SetKBatchSize() declared below, so the type stays abstract until a
/// derived operation overrides it.
template <class ALayout,
          class BLayout,
          class DsLayout,
          class ELayout,
          class ADataType,
          class BDataType,
          class DsDataType,
          class EDataType,
          class AElementwiseOperation,
          class BElementwiseOperation,
          class CElementwiseOperation>
struct DeviceGroupedGemmSplitK
    // Base access is public: that is the default for a `struct`.
    : DeviceGroupedGemm<ALayout, BLayout, DsLayout, ELayout,
                        ADataType, BDataType, DsDataType, EDataType,
                        AElementwiseOperation,
                        BElementwiseOperation,
                        CElementwiseOperation>
{
    /// Hook that concrete operations must implement.
    ///
    /// @param p_arg   Argument object handed in through its BaseArgument
    ///                pointer. NOTE(review): presumably the argument created
    ///                by the same operation and updated in place - confirm
    ///                against the implementations.
    /// @param kbatch  Value to apply. NOTE(review): presumably the number of
    ///                splits along the K dimension - confirm.
    ///
    /// Declared `const`, so it can be called through a const reference or
    /// pointer to the operation.
    virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
...@@ -118,17 +118,17 @@ template <typename ALayout, ...@@ -118,17 +118,17 @@ template <typename ALayout,
enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> && enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> &&
is_same_v<DsDataType, ck::Tuple<>>, is_same_v<DsDataType, ck::Tuple<>>,
bool> = false> bool> = false>
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemm<ALayout, struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
ELayout, ELayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType, DsDataType,
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation>
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -183,9 +183,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemm<ALayout, ...@@ -183,9 +183,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemm<ALayout,
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit = using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>; BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument;
using KernelArgument = typename GridwiseGemm::Argument;
struct GemmTransKernelArg struct GemmTransKernelArg
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment