Commit 165e30cd authored by Chao Liu's avatar Chao Liu
Browse files

adding bias add

parent d5679ea6
......@@ -12,13 +12,14 @@
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
......@@ -32,9 +33,13 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC* __restrict__ p_c0_grid,
const FloatC* __restrict__ p_c1_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
......@@ -48,75 +53,19 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_c0_grid,
p_c1_grid,
p_shared_block,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r5(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_grid_desc_k0_m_k1,
const void CONSTANT* p_b_grid_desc_k0_n_k1,
const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
const auto a_grid_desc_k0_m_k1 = *reinterpret_cast<const AGridDesc_K0_M_K1*>(
cast_pointer_to_generic_address_space(p_a_grid_desc_k0_m_k1));
const auto b_grid_desc_k0_n_k1 = *reinterpret_cast<const BGridDesc_K0_N_K1*>(
cast_pointer_to_generic_address_space(p_b_grid_desc_k0_n_k1));
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
*reinterpret_cast<const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2*>(
cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2));
const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
cast_pointer_to_generic_address_space(p_block_2_ctile_map));
const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_a_element_op));
const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_b_element_op));
const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_c_element_op));
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#endif
template <index_t BlockSize,
typename FloatAB,
......@@ -126,6 +75,8 @@ template <index_t BlockSize,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
typename C0GridDesc_M_N,
typename C1GridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
......@@ -281,8 +232,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
return has_main_k0_block_loop;
}
// TODO fix this
template <typename CGridDesc_M_N_any>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N_any& c_grid_desc_m_n)
{
constexpr auto max_lds_align = K1;
......@@ -369,6 +322,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{}));
using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{}));
using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
template <bool HasMainKBlockLoop>
......@@ -376,10 +336,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC* __restrict__ p_c0_grid,
const FloatC* __restrict__ p_c1_grid,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
......@@ -392,6 +356,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c1_grid, c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// divide block work by [M, N]
......
......@@ -14,7 +14,9 @@
#include "device_base.hpp"
#include "example/2_gemm_xdl_bias_add/include/device_gemm_xdl_bias_add.hpp"
// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * D[m] + gamma * E[m, n]
// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m]
// assume C0 has same layout as C
// assume C1 is contiguous in memory
struct PassThrough
{
......@@ -175,9 +177,19 @@ int main(int argc, char* argv[])
Tensor<BDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<BDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
// C0[m ,n]
Tensor<BDataType> c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
// C1[m]
Tensor<CDataType> c1_m_n(HostTensorDescriptor(
std::vector<std::size_t>({static_cast<std::size_t>(M), static_cast<std::size_t>(N)}),
std::vector<std::size_t>({1, 0})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl;
std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl;
switch(init_method)
{
......@@ -185,19 +197,27 @@ int main(int argc, char* argv[])
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
c0_m_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
c1_m_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
c0_m_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0});
c1_m_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem c0_m_n_device_buf(sizeof(CDataType) * c0_m_n.mDesc.GetElementSpace());
DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
// do GEMM
auto gemm = typename DeviceGemmInstance<ADataType,
......@@ -214,6 +234,8 @@ int main(int argc, char* argv[])
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c0_m_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c1_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
......
#ifndef DEVICE_GEMM_HPP
#define DEVICE_GEMM_HPP
#include <iostream>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmPtr = std::unique_ptr<
DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -129,6 +129,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// hardcoding
// TODO: fix this
using C1GridDesc_M_N =
decltype(make_naive_tensor_descriptor(make_tuple(1, 1), make_tuple(I1, I0)));
// TODO remove these hacks
static constexpr auto a_k0_m_k1_grid_step_hacks =
......@@ -179,6 +185,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
C0GridDesc_M_N,
C1GridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -221,6 +229,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{}));
using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
// Argument
......@@ -229,6 +243,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
const CDataType* p_c0_grid,
const CDataType* p_c1_grid,
index_t M,
index_t N,
index_t K,
......@@ -243,10 +259,16 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_c0_grid_{p_c0_grid},
p_c1_grid_{p_c1_grid},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c0_grid_desc_m_n_{},
c1_grid_desc_m_n_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
......@@ -261,12 +283,27 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
c_grid_desc_m_n_ =
DeviceGemmXdl_two_extra_source_reduce::MakeCGridDescriptor_M_N(M, N, StrideC);
// assume C0 has same layout as C
// TODO: fix this
c0_grid_desc_m_n_ =
DeviceGemmXdl_two_extra_source_reduce::MakeCGridDescriptor_M_N(M, N, StrideC);
// hardcoding C1 layout
// TODO: fix this
c1_grid_desc_m_n_ = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, I0));
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c0_grid_desc_m_n_);
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c1_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -275,10 +312,16 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const CDataType* p_c0_grid_;
const CDataType* p_c1_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_M_N c0_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
......@@ -305,6 +348,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
......@@ -335,6 +384,10 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
remove_reference_t<DeviceGemmXdl_two_extra_source_reduce::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceGemmXdl_two_extra_source_reduce::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<
DeviceGemmXdl_two_extra_source_reduce::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<
DeviceGemmXdl_two_extra_source_reduce::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -349,9 +402,13 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -367,6 +424,10 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
remove_reference_t<DeviceGemmXdl_two_extra_source_reduce::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceGemmXdl_two_extra_source_reduce::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<
DeviceGemmXdl_two_extra_source_reduce::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<
DeviceGemmXdl_two_extra_source_reduce::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -381,9 +442,13 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -424,6 +489,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const CDataType* p_c0,
const CDataType* p_c1,
index_t M,
index_t N,
index_t K,
......@@ -437,6 +504,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
return Argument{p_a,
p_b,
p_c,
p_c0,
p_c1,
M,
N,
K,
......@@ -456,6 +525,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
index_t M,
index_t N,
index_t K,
......@@ -469,6 +540,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const CDataType*>(p_c0),
static_cast<const CDataType*>(p_c1),
M,
N,
K,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment