Commit 0b11569f authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute

parents e8d3a0fb fa9a0a5c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
@@ -7,7 +10,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
@@ -32,7 +35,7 @@ template <typename ADataType,
          index_t DScalarPerVector,
          index_t EScalarPerVector,
          index_t FScalarPerVector>
-struct Device5AryElementwise : public BaseOperator
+struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
{
    static constexpr auto I0 = Number<0>{};
@@ -265,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
        return true;
    };
-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             const CDataType* p_c,
-                             const DDataType* p_d,
-                             const EDataType* p_e,
-                             FDataType* p_f,
+    static auto MakeArgument(std::array<const void*, 5> p_inputs,
+                             std::array<void*, 1> p_outputs,
                             std::vector<index_t> lengths,
                             std::vector<index_t> a_strides,
                             std::vector<index_t> b_strides,
@@ -280,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
                             std::vector<index_t> f_strides,
                             ElementwiseFunctor functor)
    {
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        p_d,
-                        p_e,
-                        p_f,
+        return Argument{static_cast<const ADataType*>(p_inputs[0]),
+                        static_cast<const BDataType*>(p_inputs[1]),
+                        static_cast<const CDataType*>(p_inputs[2]),
+                        static_cast<const DDataType*>(p_inputs[3]),
+                        static_cast<const EDataType*>(p_inputs[4]),
+                        static_cast<FDataType*>(p_outputs[0]),
                        lengths,
                        a_strides,
                        b_strides,
@@ -296,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
                        functor};
    }
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      const void* p_c,
-                                                      const void* p_d,
-                                                      const void* p_e,
-                                                      void* p_f,
-                                                      std::vector<index_t> lengths,
-                                                      std::vector<index_t> a_strides,
-                                                      std::vector<index_t> b_strides,
-                                                      std::vector<index_t> c_strides,
-                                                      std::vector<index_t> d_strides,
-                                                      std::vector<index_t> e_strides,
-                                                      std::vector<index_t> f_strides,
-                                                      ElementwiseFunctor functor)
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::array<const void*, 5> p_inputs,
+                        std::array<void*, 1> p_outputs,
+                        std::vector<index_t> lengths,
+                        std::vector<std::vector<index_t>> input_strides,
+                        std::vector<std::vector<index_t>> output_strides,
+                        ElementwiseFunctor functor) override
    {
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
-                                          static_cast<const CDataType*>(p_c),
-                                          static_cast<const DDataType*>(p_d),
-                                          static_cast<const EDataType*>(p_e),
-                                          static_cast<FDataType*>(p_f),
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
+                                          static_cast<const BDataType*>(p_inputs[1]),
+                                          static_cast<const CDataType*>(p_inputs[2]),
+                                          static_cast<const DDataType*>(p_inputs[3]),
+                                          static_cast<const EDataType*>(p_inputs[4]),
+                                          static_cast<FDataType*>(p_outputs[0]),
                                          lengths,
-                                          a_strides,
-                                          b_strides,
-                                          c_strides,
-                                          d_strides,
-                                          e_strides,
-                                          f_strides,
+                                          input_strides[0],
+                                          input_strides[1],
+                                          input_strides[2],
+                                          input_strides[3],
+                                          input_strides[4],
+                                          output_strides[0],
                                          functor);
    }
    static auto MakeInvoker() { return Invoker{}; }
-    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
-};
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    }
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+        // clang-format off
+        str << "Device5aryElementwise"
+            << "<"
+            << "NDim = " << NDim
+            << "MPerThread = " << MPerThread
+            << "AScalarPerVector = " << AScalarPerVector
+            << "BScalarPerVector = " << BScalarPerVector
+            << "CScalarPerVector = " << CScalarPerVector
+            << "DScalarPerVector = " << DScalarPerVector
+            << "EScalarPerVector = " << EScalarPerVector
+            << "FScalarPerVector = " << FScalarPerVector
+            << ">";
+        // clang-format on
+        return str.str();
+    }
+};
} // namespace device
} // namespace tensor_operation
......
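With this change, `Device5AryElementwise` implements the generic `DeviceElementwise<5, 1, NDim, ElementwiseFunctor>` interface: callers hand over type-erased arrays of input/output pointers plus one stride vector per tensor, and the device op casts the pointers back to the concrete element types internally. A minimal standalone sketch of that type-erasure pattern in plain C++ (illustrative names only, not code from this commit):

```cpp
#include <array>
#include <cstdio>

// Sketch: recover typed pointers from a type-erased pointer array, as the new
// MakeArgument/MakeArgumentPointer overloads do with one static_cast per input.
float mean_of_all(std::array<const void*, 2> p_inputs, int n)
{
    const auto* p_a = static_cast<const float*>(p_inputs[0]); // assumed float input
    const auto* p_b = static_cast<const float*>(p_inputs[1]); // assumed float input
    float sum       = 0.f;
    for(int i = 0; i < n; ++i)
        sum += p_a[i] + p_b[i];
    return sum / (2.f * n);
}

int main()
{
    float a[4] = {1, 2, 3, 4};
    float b[4] = {5, 6, 7, 8};
    std::printf("%f\n", mean_of_all({a, b}, 4)); // caller only supplies void pointers
    return 0;
}
```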
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              ck::index_t M,
                                                              ck::index_t N,
                                                              ck::index_t K,
                                                              ck::index_t StrideA,
                                                              ck::index_t StrideB,
                                                              ck::index_t StrideC,
                                                              AElementwiseOperation a_element_op,
                                                              BElementwiseOperation b_element_op,
                                                              CElementwiseOperation c_element_op,
                                                              ck::index_t Batch) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmPtr = std::unique_ptr<
    DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
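The new `DeviceBatchedGemm` header gives batched GEMM its own abstract interface (previously `DeviceBatchedGemmXdl` reused `DeviceGemm`), so client code can hold any implementation behind a `DeviceBatchedGemmPtr` and build arguments and invokers polymorphically. A small standalone analogue of that interface-plus-factory pattern (hypothetical types, not CK code):

```cpp
#include <cstdio>
#include <memory>
#include <vector>

// Standalone analogue of DeviceBatchedGemm/DeviceBatchedGemmPtr: an abstract op,
// a unique_ptr alias, and a concrete implementation selected at runtime.
struct BatchedOp
{
    virtual void Run(const float* a, const float* b, float* c, int m, int n, int k, int batch) = 0;
    virtual ~BatchedOp() = default;
};

using BatchedOpPtr = std::unique_ptr<BatchedOp>;

struct NaiveBatchedGemm : BatchedOp
{
    void Run(const float* a, const float* b, float* c, int m, int n, int k, int batch) override
    {
        // row-major C[g] = A[g] * B[g] for each batch index g
        for(int g = 0; g < batch; ++g)
            for(int i = 0; i < m; ++i)
                for(int j = 0; j < n; ++j)
                {
                    float acc = 0.f;
                    for(int p = 0; p < k; ++p)
                        acc += a[g * m * k + i * k + p] * b[g * k * n + p * n + j];
                    c[g * m * n + i * n + j] = acc;
                }
    }
};

int main()
{
    BatchedOpPtr op = std::make_unique<NaiveBatchedGemm>();
    std::vector<float> a(2 * 2 * 2, 1.f), b(2 * 2 * 2, 1.f), c(2 * 2 * 2, 0.f);
    op->Run(a.data(), b.data(), c.data(), 2, 2, 2, 2);
    std::printf("c[0] = %f\n", c[0]); // 2.0 for 2x2x2 all-ones inputs
    return 0;
}
```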
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
@@ -20,16 +23,16 @@ namespace device {
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename DGridDescriptor_MBlock_MPerBlock,
+          typename ReduceGridDescriptor_MBlock_MPerBlock,
          typename ComputeBasePrtOfBatch,
          typename Block2CTileMap,
          bool HasMainK0BlockLoop>
@@ -41,18 +44,18 @@ __global__ void
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
-        DPtrsGlobal p_ds_grid,
+        ReducePtrsGlobal p_reduces_grid,
        const index_t batch_count,
        const AElementwiseOperation a_element_op,
        const BElementwiseOperation b_element_op,
        const CElementwiseOperation c_element_op,
-        const DxsInElementwiseOperation dxs_in_element_op,
-        const DxsReduceAccElementwiseOperation dxs_out_element_op,
+        const ReduceInElementwiseOperations reduce_in_element_ops,
+        const ReduceAccElementwiseOperations reduce_out_element_ops,
        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
+        const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
        const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
        const Block2CTileMap block_2_ctile_map)
{
@@ -68,10 +71,10 @@ __global__ void
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
-    static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
+    static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
        const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
-        p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
+        p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset;
    });
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -79,36 +82,36 @@ __global__ void
    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_c_grid + c_batch_offset,
-                                                   p_ds_grid,
+                                                   p_reduces_grid,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
-                                                   dxs_in_element_op,
-                                                   dxs_out_element_op,
+                                                   reduce_in_element_ops,
+                                                   reduce_out_element_ops,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                   d_grid_desc_mblock_mperblock,
+                                                   reduce_grid_desc_mblock_mperblock,
                                                   block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
-    ignore = p_ds_grid;
+    ignore = p_reduces_grid;
    ignore = batch_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
-    ignore = dxs_in_element_op;
-    ignore = dxs_out_element_op;
+    ignore = reduce_in_element_ops;
+    ignore = reduce_out_element_ops;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = d_grid_desc_mblock_mperblock;
+    ignore = reduce_grid_desc_mblock_mperblock;
    ignore = compute_base_ptr_of_batch_;
    ignore = block_2_ctile_map;
-#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
+#endif
}
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
@@ -123,14 +126,14 @@ template <typename ALayout,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename ReduceAccDataType,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
-          typename DxsReduceOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
-          typename DGlobalMemoryDataOperation,
+          typename ReduceOperations,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
+          typename ReduceGlobalMemoryDataOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
@@ -165,12 +168,7 @@ template <typename ALayout,
          index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceBatchedGemmReduce_Xdl_CShuffle
-    : public DeviceGemmReduce<AElementwiseOperation,
-                              BElementwiseOperation,
-                              CElementwiseOperation,
-                              DxsInElementwiseOperation,
-                              DxsReduceAccElementwiseOperation>
+struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
{
    using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
@@ -443,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
    }
    // assume D is packed tensor
-    static auto MakeDGridDescriptor_M(index_t MRaw)
+    static auto MakeReduceGridDescriptor_M(index_t MRaw)
    {
        const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
@@ -471,7 +469,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
-    using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
+    using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
    struct ComputeBasePtrOfStridedBatch
    {
@@ -524,19 +522,19 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
        CShuffleDataType,
        CDataType,
        ReduceAccDataType,
-        DPtrsGlobal,
+        ReducePtrsGlobal,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
-        DxsReduceOperation,
-        DxsInElementwiseOperation,
-        DxsReduceAccElementwiseOperation,
+        ReduceOperations,
+        ReduceInElementwiseOperations,
+        ReduceAccElementwiseOperations,
        InMemoryDataOperationEnum::Set,
-        DGlobalMemoryDataOperation,
+        ReduceGlobalMemoryDataOperation,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        CGridDesc_M_N,
-        DGridDesc_M,
+        ReduceGridDesc_M,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
@@ -579,7 +577,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 CDataType* p_c_grid,
-                 DPtrsGlobal p_ds_grid,
+                 ReducePtrsGlobal p_reduces_grid,
                 index_t MRaw,
                 index_t NRaw,
                 index_t KRaw,
@@ -589,31 +587,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op,
-                 DxsInElementwiseOperation dxs_in_element_op,
-                 DxsReduceAccElementwiseOperation dxs_out_element_op,
-                 index_t BatchCount)
+                 ReduceInElementwiseOperations reduce_in_element_ops,
+                 ReduceAccElementwiseOperations reduce_out_element_ops,
+                 index_t Batch)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
-              p_ds_grid_{p_ds_grid},
-              BatchCount_(BatchCount),
+              p_reduces_grid_{p_reduces_grid},
+              Batch_(Batch),
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
-              d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
+              reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              d_grid_desc_mblock_mperblock_{},
+              reduce_grid_desc_mblock_mperblock_{},
              compute_base_ptr_of_batch_{
                  type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
                  type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
-                  type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
+                  type_convert<index_t>(reduce_grid_desc_m_.GetElementSpaceSize())},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op},
-              dxs_in_element_op_{dxs_in_element_op},
-              dxs_out_element_op_{dxs_out_element_op}
+              reduce_in_element_ops_{reduce_in_element_ops},
+              reduce_out_element_ops_{reduce_out_element_ops}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
@@ -624,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
-                d_grid_desc_mblock_mperblock_ =
-                    GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
+                reduce_grid_desc_mblock_mperblock_ =
+                    GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
            }
        }
@@ -633,22 +631,23 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
-        DPtrsGlobal p_ds_grid_;
-        index_t BatchCount_;
+        ReducePtrsGlobal p_reduces_grid_;
+        index_t Batch_;
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        CGridDesc_M_N c_grid_desc_m_n_;
-        DGridDesc_M d_grid_desc_m_;
+        ReduceGridDesc_M reduce_grid_desc_m_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
+        typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
+            reduce_grid_desc_mblock_mperblock_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
-        DxsInElementwiseOperation dxs_in_element_op_;
-        DxsReduceAccElementwiseOperation dxs_out_element_op_;
+        ReduceInElementwiseOperations reduce_in_element_ops_;
+        ReduceAccElementwiseOperations reduce_out_element_ops_;
    };
    // Invoker
@@ -660,7 +659,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
        {
#if 0
            {
-                std::cout << "arg.BatchCount_ = " << arg.BatchCount_ << std::endl;
+                std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
@@ -675,7 +674,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-                std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
+                std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
                          << std::endl;
            }
#endif
@@ -689,7 +688,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
            }
            const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -701,16 +700,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                GridwiseGemm,
                ADataType, // TODO: distiguish A/B datatype
                CDataType,
-                DPtrsGlobal,
+                ReducePtrsGlobal,
                AElementwiseOperation,
                BElementwiseOperation,
                CElementwiseOperation,
-                DxsInElementwiseOperation,
-                DxsReduceAccElementwiseOperation,
+                ReduceInElementwiseOperations,
+                ReduceAccElementwiseOperations,
                DeviceOp::AGridDesc_AK0_M_AK1,
                DeviceOp::BGridDesc_BK0_N_BK1,
                typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
+                typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
                ComputeBasePtrOfStridedBatch,
                typename GridwiseGemm::DefaultBlock2CTileMap,
                true>;
@@ -724,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                arg.p_a_grid_,
                arg.p_b_grid_,
                arg.p_c_grid_,
-                arg.p_ds_grid_,
-                arg.BatchCount_,
+                arg.p_reduces_grid_,
+                arg.Batch_,
                arg.a_element_op_,
                arg.b_element_op_,
                arg.c_element_op_,
-                arg.dxs_in_element_op_,
-                arg.dxs_out_element_op_,
+                arg.reduce_in_element_ops_,
+                arg.reduce_out_element_ops_,
                arg.a_grid_desc_ak0_m_ak1_,
                arg.b_grid_desc_bk0_n_bk1_,
                arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                arg.d_grid_desc_mblock_mperblock_,
+                arg.reduce_grid_desc_mblock_mperblock_,
                arg.compute_base_ptr_of_batch_,
                arg.block_2_ctile_map_);
        }
@@ -744,16 +743,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                GridwiseGemm,
                ADataType, // TODO: distiguish A/B datatype
                CDataType,
-                DPtrsGlobal,
+                ReducePtrsGlobal,
                AElementwiseOperation,
                BElementwiseOperation,
                CElementwiseOperation,
-                DxsInElementwiseOperation,
-                DxsReduceAccElementwiseOperation,
+                ReduceInElementwiseOperations,
+                ReduceAccElementwiseOperations,
                DeviceOp::AGridDesc_AK0_M_AK1,
                DeviceOp::BGridDesc_BK0_N_BK1,
                typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
+                typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
                ComputeBasePtrOfStridedBatch,
                typename GridwiseGemm::DefaultBlock2CTileMap,
                false>;
@@ -767,17 +766,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                arg.p_a_grid_,
                arg.p_b_grid_,
                arg.p_c_grid_,
-                arg.p_ds_grid_,
-                arg.BatchCount_,
+                arg.p_reduces_grid_,
+                arg.Batch_,
                arg.a_element_op_,
                arg.b_element_op_,
                arg.c_element_op_,
-                arg.dxs_in_element_op_,
-                arg.dxs_out_element_op_,
+                arg.reduce_in_element_ops_,
+                arg.reduce_out_element_ops_,
                arg.a_grid_desc_ak0_m_ak1_,
                arg.b_grid_desc_bk0_n_bk1_,
                arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                arg.d_grid_desc_mblock_mperblock_,
+                arg.reduce_grid_desc_mblock_mperblock_,
                arg.compute_base_ptr_of_batch_,
                arg.block_2_ctile_map_);
        }
@@ -821,39 +820,77 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
        }
    }
-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             CDataType* p_c,
-                             DPtrsGlobal p_dxs,
-                             index_t MRaw,
-                             index_t NRaw,
-                             index_t KRaw,
-                             index_t StrideA,
-                             index_t StrideB,
-                             index_t StrideC,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
-                             DxsInElementwiseOperation dxs_in_element_op,
-                             DxsReduceAccElementwiseOperation dxs_out_element_op,
-                             index_t BatchCount)
+    static constexpr int NumReduce = ReduceOperations::Size();
+    static auto MakeArgument(const void* p_a,
+                             const void* p_b,
+                             const void* p_bias,
+                             std::array<const void*, 0> p_ds,
+                             void* p_c,
+                             std::array<void*, NumReduce> p_reduces,
+                             ck::index_t M,
+                             ck::index_t N,
+                             ck::index_t K,
+                             ck::index_t StrideA,
+                             ck::index_t StrideB,
+                             ck::index_t StrideC,
+                             std::array<ck::index_t, 0> StrideDs,
+                             std::array<void*, 3> gemm_element_ops,
+                             std::array<void*, 0> d_element_ops,
+                             std::array<void*, NumReduce> reduce_in_element_op,
+                             std::array<void*, NumReduce> reduce_out_element_op,
+                             index_t Batch)
    {
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        p_dxs,
-                        MRaw,
-                        NRaw,
-                        KRaw,
+        (void)p_bias;
+        (void)p_ds;
+        (void)StrideDs;
+        (void)d_element_ops;
+        ReducePtrsGlobal reduce_tuple = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReducePtrsGlobal{}[I];
+                using T = remove_pointer_t<decltype(tmp)>;
+                return static_cast<T*>(p_reduces[I]);
+            },
+            Number<NumReduce>{});
+        ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceInElementwiseOperations{}[I];
+                using T = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_in_element_op[I]));
+            },
+            Number<NumReduce>{});
+        ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceAccElementwiseOperations{}[I];
+                using T = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_out_element_op[I]));
+            },
+            Number<NumReduce>{});
+        AElementwiseOperation a_element_op =
+            *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
+        BElementwiseOperation b_element_op =
+            *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
+        CElementwiseOperation c_element_op =
+            *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
+        return Argument{static_cast<const ADataType*>(p_a),
+                        static_cast<const BDataType*>(p_b),
+                        static_cast<CDataType*>(p_c),
+                        reduce_tuple,
+                        M,
+                        N,
+                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        a_element_op,
                        b_element_op,
                        c_element_op,
-                        dxs_in_element_op,
-                        dxs_out_element_op,
-                        BatchCount};
+                        reduce_in_element_ops,
+                        reduce_out_element_ops,
+                        Batch};
    }
    static auto MakeInvoker() { return Invoker{}; }
@@ -862,38 +899,74 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
+                        const void* p_bias,
+                        std::array<const void*, 0> p_ds,
                        void* p_c,
-                        void* p_dxs,
-                        index_t MRaw,
-                        index_t NRaw,
-                        index_t KRaw,
-                        index_t StrideA,
-                        index_t StrideB,
-                        index_t StrideC,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
-                        index_t BatchCount) override
+                        std::array<void*, NumReduce> p_reduces,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        std::array<ck::index_t, 0> StrideDs,
+                        std::array<void*, 3> gemm_element_ops,
+                        std::array<void*, 0> d_element_ops,
+                        std::array<void*, NumReduce> reduce_in_element_op,
+                        std::array<void*, NumReduce> reduce_out_element_op,
+                        index_t Batch = 1) override
    {
-        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
+        (void)p_bias;
+        (void)p_ds;
+        (void)StrideDs;
+        (void)d_element_ops;
+        ReducePtrsGlobal reduce_tuple = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReducePtrsGlobal{}[I];
+                using T = remove_pointer_t<decltype(tmp)>;
+                return static_cast<T*>(p_reduces[I]);
+            },
+            Number<NumReduce>{});
+        ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceInElementwiseOperations{}[I];
+                using T = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_in_element_op[I]));
+            },
+            Number<NumReduce>{});
+        ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceAccElementwiseOperations{}[I];
+                using T = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_out_element_op[I]));
+            },
+            Number<NumReduce>{});
+        AElementwiseOperation a_element_op =
+            *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
+        BElementwiseOperation b_element_op =
+            *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
+        CElementwiseOperation c_element_op =
+            *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
-                                          dxs_tuple,
-                                          MRaw,
-                                          NRaw,
-                                          KRaw,
+                                          reduce_tuple,
+                                          M,
+                                          N,
+                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op,
-                                          dxs_in_element_op,
-                                          dxs_out_element_op,
-                                          BatchCount);
+                                          reduce_in_element_ops,
+                                          reduce_out_element_ops,
+                                          Batch);
    }
    // polymorphic
......
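In the new `MakeArgument`/`MakeArgumentPointer` overloads above, `generate_tuple` rebuilds the statically typed tuples (`ReducePtrsGlobal` and the reduce element-op tuples) from the type-erased `std::array<void*, NumReduce>` arguments, one cast per compile-time index. A minimal standalone analogue of that index-driven tuple building, written with `std::index_sequence` instead of CK's `generate_tuple` (hypothetical example, not code from this commit):

```cpp
#include <array>
#include <cstdio>
#include <tuple>
#include <utility>

// Build a tuple of typed pointers from an array of void* by visiting each
// compile-time index, mirroring what generate_tuple does in the device op.
template <typename PtrTuple, std::size_t... Is>
PtrTuple make_typed_ptrs_impl(const std::array<void*, sizeof...(Is)>& erased,
                              std::index_sequence<Is...>)
{
    return PtrTuple{static_cast<std::tuple_element_t<Is, PtrTuple>>(erased[Is])...};
}

template <typename PtrTuple, std::size_t N>
PtrTuple make_typed_ptrs(const std::array<void*, N>& erased)
{
    static_assert(std::tuple_size_v<PtrTuple> == N, "arity mismatch");
    return make_typed_ptrs_impl<PtrTuple>(erased, std::make_index_sequence<N>{});
}

int main()
{
    float mean   = 0.f;
    double sumsq = 0.0;
    using ReducePtrs = std::tuple<float*, double*>; // stand-in for ReducePtrsGlobal

    std::array<void*, 2> p_reduces = {&mean, &sumsq}; // type-erased, as in MakeArgumentPointer
    ReducePtrs typed = make_typed_ptrs<ReducePtrs>(p_reduces);

    *std::get<0>(typed) = 1.5f;
    *std::get<1>(typed) = 2.25;
    std::printf("%f %f\n", mean, sumsq);
    return 0;
}
```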
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
@@ -7,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/device_utility/device_prop.hpp"
@@ -149,7 +152,7 @@ template <typename ADataType,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceBatchedGemmXdl
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+    : public DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -336,11 +339,11 @@ struct DeviceBatchedGemmXdl
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op,
-                 index_t BatchCount)
+                 index_t Batch)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
-              BatchCount_(BatchCount),
+              Batch_(Batch),
              a_grid_desc_k0_m_k1_{
                  DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)},
              b_grid_desc_k0_n_k1_{
@@ -373,7 +376,7 @@ struct DeviceBatchedGemmXdl
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
-        index_t BatchCount_;
+        index_t Batch_;
        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
@@ -417,7 +420,7 @@ struct DeviceBatchedGemmXdl
            }
            const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
            const auto K =
                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -448,7 +451,7 @@ struct DeviceBatchedGemmXdl
                arg.p_a_grid_,
                arg.p_b_grid_,
                arg.p_c_grid_,
-                arg.BatchCount_,
+                arg.Batch_,
                arg.a_grid_desc_k0_m_k1_,
                arg.b_grid_desc_k0_n_k1_,
                arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
@@ -482,7 +485,7 @@ struct DeviceBatchedGemmXdl
                arg.p_a_grid_,
                arg.p_b_grid_,
                arg.p_c_grid_,
-                arg.BatchCount_,
+                arg.Batch_,
                arg.a_grid_desc_k0_m_k1_,
                arg.b_grid_desc_k0_n_k1_,
                arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
@@ -536,7 +539,7 @@ struct DeviceBatchedGemmXdl
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op,
-                             index_t BatchCount)
+                             index_t Batch)
    {
        return Argument{p_a,
                        p_b,
@@ -552,7 +555,7 @@ struct DeviceBatchedGemmXdl
                        a_element_op,
                        b_element_op,
                        c_element_op,
-                        BatchCount};
+                        Batch};
    }
    static auto MakeInvoker() { return Invoker{}; }
@@ -570,7 +573,7 @@ struct DeviceBatchedGemmXdl
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op,
-                        index_t BatchCount) override
+                        index_t Batch) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
@@ -586,7 +589,7 @@ struct DeviceBatchedGemmXdl
                                          a_element_op,
                                          b_element_op,
                                          c_element_op,
-                                          BatchCount);
+                                          Batch);
    }
    // polymorphic
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
@@ -6,6 +9,7 @@
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace ck {
@@ -22,7 +26,7 @@ template <typename ADataType,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector>
-struct DeviceBinaryElementwise : public BaseOperator
+struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{
    static constexpr auto I0 = Number<0>{};
@@ -195,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
        return true;
    };
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      void* p_c,
-                                                      std::vector<index_t> lengths,
-                                                      std::vector<index_t> a_strides,
-                                                      std::vector<index_t> b_strides,
-                                                      std::vector<index_t> c_strides,
-                                                      ElementwiseFunctor functor)
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::array<const void*, 2> p_inputs,
+                        std::array<void*, 1> p_outputs,
+                        std::vector<index_t> lengths,
+                        std::vector<std::vector<index_t>> input_strides,
+                        std::vector<std::vector<index_t>> output_strides,
+                        ElementwiseFunctor functor) override
    {
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
-                                          static_cast<CDataType*>(p_c),
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
+                                          static_cast<const BDataType*>(p_inputs[1]),
+                                          static_cast<CDataType*>(p_outputs[0]),
                                          lengths,
-                                          a_strides,
-                                          b_strides,
-                                          c_strides,
+                                          input_strides[0],
+                                          input_strides[1],
+                                          output_strides[0],
                                          functor);
    }
-    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    }
+    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
@@ -223,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
        // clang-format off
        str << "DeviceBinaryElementwise"
            << "<"
+            << "NDim = " << NDim
            << "MPerThread = " << MPerThread
+            << "AScalarPerVector = " << AScalarPerVector
+            << "BScalarPerVector = " << BScalarPerVector
+            << "CScalarPerVector = " << CScalarPerVector
            << ">";
        // clang-format on
......
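`DeviceBinaryElementwise` now takes the same interface shape as the 5-ary op: an array of input pointers, an array of output pointers, one `lengths` vector, and one stride vector per tensor. A sketch of how a caller might assemble those containers for a 2-input, 1-output op on an M x N tensor, assuming the usual row-major stride convention (illustrative only; the concrete device op and functor come from the CK instance being used and are not shown):

```cpp
#include <array>
#include <vector>

int main()
{
    using index_t = int; // stand-in for ck::index_t in this sketch

    const index_t M = 1024;
    const index_t N = 1024;

    std::vector<float> a(M * N), b(M * N), c(M * N);

    std::array<const void*, 2> p_inputs  = {a.data(), b.data()};
    std::array<void*, 1>       p_outputs = {c.data()};

    std::vector<index_t> lengths = {M, N};
    // one stride vector per tensor: row-major {N, 1} for A, B and C (assumed layout)
    std::vector<std::vector<index_t>> input_strides  = {{N, 1}, {N, 1}};
    std::vector<std::vector<index_t>> output_strides = {{N, 1}};

    // These containers would then be handed to the device op, e.g.
    // device_op.MakeArgumentPointer(p_inputs, p_outputs, lengths,
    //                               input_strides, output_strides, functor);
    (void)p_inputs;
    (void)p_outputs;
    (void)lengths;
    (void)input_strides;
    (void)output_strides;
    return 0;
}
```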