Commit 0b11569f authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute

parents e8d3a0fb fa9a0a5c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP #ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP #ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/tensor_description/cluster_descriptor.hpp" #include "ck/tensor_description/cluster_descriptor.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION #ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION #define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
namespace ck { namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION #ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION #define CONVOLUTION_FORWARD_SPECIALIZATION
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
...@@ -7,7 +10,7 @@ ...@@ -7,7 +10,7 @@
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp" #include "ck/device_utility/kernel_launch.hpp"
...@@ -32,7 +35,7 @@ template <typename ADataType, ...@@ -32,7 +35,7 @@ template <typename ADataType,
index_t DScalarPerVector, index_t DScalarPerVector,
index_t EScalarPerVector, index_t EScalarPerVector,
index_t FScalarPerVector> index_t FScalarPerVector>
struct Device5AryElementwise : public BaseOperator struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -265,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator ...@@ -265,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
return true; return true;
}; };
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(std::array<const void*, 5> p_inputs,
const BDataType* p_b, std::array<void*, 1> p_outputs,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
std::vector<index_t> lengths, std::vector<index_t> lengths,
std::vector<index_t> a_strides, std::vector<index_t> a_strides,
std::vector<index_t> b_strides, std::vector<index_t> b_strides,
...@@ -280,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator ...@@ -280,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
std::vector<index_t> f_strides, std::vector<index_t> f_strides,
ElementwiseFunctor functor) ElementwiseFunctor functor)
{ {
return Argument{p_a, return Argument{static_cast<const ADataType*>(p_inputs[0]),
p_b, static_cast<const BDataType*>(p_inputs[1]),
p_c, static_cast<const CDataType*>(p_inputs[2]),
p_d, static_cast<const DDataType*>(p_inputs[3]),
p_e, static_cast<const EDataType*>(p_inputs[4]),
p_f, static_cast<FDataType*>(p_outputs[0]),
lengths, lengths,
a_strides, a_strides,
b_strides, b_strides,
...@@ -296,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator ...@@ -296,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
functor}; functor};
} }
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument>
const void* p_b, MakeArgumentPointer(std::array<const void*, 5> p_inputs,
const void* p_c, std::array<void*, 1> p_outputs,
const void* p_d, std::vector<index_t> lengths,
const void* p_e, std::vector<std::vector<index_t>> input_strides,
void* p_f, std::vector<std::vector<index_t>> output_strides,
std::vector<index_t> lengths, ElementwiseFunctor functor) override
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
std::vector<index_t> d_strides,
std::vector<index_t> e_strides,
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_inputs[1]),
static_cast<const CDataType*>(p_c), static_cast<const CDataType*>(p_inputs[2]),
static_cast<const DDataType*>(p_d), static_cast<const DDataType*>(p_inputs[3]),
static_cast<const EDataType*>(p_e), static_cast<const EDataType*>(p_inputs[4]),
static_cast<FDataType*>(p_f), static_cast<FDataType*>(p_outputs[0]),
lengths, lengths,
a_strides, input_strides[0],
b_strides, input_strides[1],
c_strides, input_strides[2],
d_strides, input_strides[3],
e_strides, input_strides[4],
f_strides, output_strides[0],
functor); functor);
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); } std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
}; {
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Device5aryElementwise"
<< "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< "DScalarPerVector = " << DScalarPerVector
<< "EScalarPerVector = " << EScalarPerVector
<< "FScalarPerVector = " << FScalarPerVector
<< ">";
// clang-format on
return str.str();
}
}; // namespace device
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <string> #include <string>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t Batch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmPtr = std::unique_ptr<
DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
...@@ -7,7 +10,7 @@ ...@@ -7,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
...@@ -149,7 +152,7 @@ template <typename ADataType, ...@@ -149,7 +152,7 @@ template <typename ADataType,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceBatchedGemmXdl struct DeviceBatchedGemmXdl
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -336,11 +339,11 @@ struct DeviceBatchedGemmXdl ...@@ -336,11 +339,11 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
index_t BatchCount) index_t Batch)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
BatchCount_(BatchCount), Batch_(Batch),
a_grid_desc_k0_m_k1_{ a_grid_desc_k0_m_k1_{
DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)},
b_grid_desc_k0_n_k1_{ b_grid_desc_k0_n_k1_{
...@@ -373,7 +376,7 @@ struct DeviceBatchedGemmXdl ...@@ -373,7 +376,7 @@ struct DeviceBatchedGemmXdl
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
index_t BatchCount_; index_t Batch_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
...@@ -417,7 +420,7 @@ struct DeviceBatchedGemmXdl ...@@ -417,7 +420,7 @@ struct DeviceBatchedGemmXdl
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
const auto K = const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
...@@ -448,7 +451,7 @@ struct DeviceBatchedGemmXdl ...@@ -448,7 +451,7 @@ struct DeviceBatchedGemmXdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.BatchCount_, arg.Batch_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
...@@ -482,7 +485,7 @@ struct DeviceBatchedGemmXdl ...@@ -482,7 +485,7 @@ struct DeviceBatchedGemmXdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.BatchCount_, arg.Batch_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
...@@ -536,7 +539,7 @@ struct DeviceBatchedGemmXdl ...@@ -536,7 +539,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
index_t BatchCount) index_t Batch)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -552,7 +555,7 @@ struct DeviceBatchedGemmXdl ...@@ -552,7 +555,7 @@ struct DeviceBatchedGemmXdl
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
BatchCount}; Batch};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -570,7 +573,7 @@ struct DeviceBatchedGemmXdl ...@@ -570,7 +573,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
index_t BatchCount) override index_t Batch) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -586,7 +589,7 @@ struct DeviceBatchedGemmXdl ...@@ -586,7 +589,7 @@ struct DeviceBatchedGemmXdl
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
BatchCount); Batch);
} }
// polymorphic // polymorphic
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
...@@ -6,6 +9,7 @@ ...@@ -6,6 +9,7 @@
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp" #include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace ck { namespace ck {
...@@ -22,7 +26,7 @@ template <typename ADataType, ...@@ -22,7 +26,7 @@ template <typename ADataType,
index_t AScalarPerVector, index_t AScalarPerVector,
index_t BScalarPerVector, index_t BScalarPerVector,
index_t CScalarPerVector> index_t CScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -195,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -195,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
return true; return true;
}; };
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, virtual std::unique_ptr<BaseArgument>
const void* p_b, MakeArgumentPointer(std::array<const void*, 2> p_inputs,
void* p_c, std::array<void*, 1> p_outputs,
std::vector<index_t> lengths, std::vector<index_t> lengths,
std::vector<index_t> a_strides, std::vector<std::vector<index_t>> input_strides,
std::vector<index_t> b_strides, std::vector<std::vector<index_t>> output_strides,
std::vector<index_t> c_strides, ElementwiseFunctor functor) override
ElementwiseFunctor functor)
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_inputs[1]),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_outputs[0]),
lengths, lengths,
a_strides, input_strides[0],
b_strides, input_strides[1],
c_strides, output_strides[0],
functor); functor);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); } std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override std::string GetTypeString() const override
{ {
auto str = std::stringstream(); auto str = std::stringstream();
...@@ -223,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -223,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off // clang-format off
str << "DeviceBinaryElementwise" str << "DeviceBinaryElementwise"
<< "<" << "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread << "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< ">"; << ">";
// clang-format on // clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment