Commit 165e30cd authored by Chao Liu

adding bias add

parent d5679ea6
@@ -12,13 +12,14 @@
 namespace ck {

-#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
+          typename C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
+          typename C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
@@ -32,9 +33,13 @@ __global__ void
         const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
+        const FloatC* __restrict__ p_c0_grid,
+        const FloatC* __restrict__ p_c1_grid,
         const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
         const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+        const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+        const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const CElementwiseOperation c_element_op,
@@ -48,75 +53,19 @@ __global__ void
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
+                                                  p_c0_grid,
+                                                  p_c1_grid,
                                                   p_shared_block,
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
                                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   block_2_ctile_map);
 }
-#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
-template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename AGridDesc_K0_M_K1,
-          typename BGridDesc_K0_N_K1,
-          typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename Block2CTileMap>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
-    kernel_gemm_xdlops_v2r5(const FloatAB* __restrict__ p_a_grid,
-                            const FloatAB* __restrict__ p_b_grid,
-                            FloatC* __restrict__ p_c_grid,
-                            const void CONSTANT* p_a_grid_desc_k0_m_k1,
-                            const void CONSTANT* p_b_grid_desc_k0_n_k1,
-                            const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                            const void CONSTANT* p_a_element_op,
-                            const void CONSTANT* p_b_element_op,
-                            const void CONSTANT* p_c_element_op,
-                            const void CONSTANT* p_block_2_ctile_map)
-{
-    constexpr index_t shared_block_size =
-        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
-    const auto a_grid_desc_k0_m_k1 = *reinterpret_cast<const AGridDesc_K0_M_K1*>(
-        cast_pointer_to_generic_address_space(p_a_grid_desc_k0_m_k1));
-    const auto b_grid_desc_k0_n_k1 = *reinterpret_cast<const BGridDesc_K0_N_K1*>(
-        cast_pointer_to_generic_address_space(p_b_grid_desc_k0_n_k1));
-    const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-        *reinterpret_cast<const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2*>(
-            cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2));
-
-    const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
-        cast_pointer_to_generic_address_space(p_block_2_ctile_map));
-
-    const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
-        cast_pointer_to_generic_address_space(p_a_element_op));
-    const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
-        cast_pointer_to_generic_address_space(p_b_element_op));
-    const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
-        cast_pointer_to_generic_address_space(p_c_element_op));
-
-    __shared__ FloatAB p_shared_block[shared_block_size];
-
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
-                                                  p_b_grid,
-                                                  p_c_grid,
-                                                  p_shared_block,
-                                                  a_grid_desc_k0_m_k1,
-                                                  b_grid_desc_k0_n_k1,
-                                                  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  a_element_op,
-                                                  b_element_op,
-                                                  c_element_op,
-                                                  block_2_ctile_map);
-}
-#endif
 template <index_t BlockSize,
           typename FloatAB,
@@ -126,6 +75,8 @@ template <index_t BlockSize,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M_N,
+          typename C0GridDesc_M_N,
+          typename C1GridDesc_M_N,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
@@ -281,8 +232,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
         return has_main_k0_block_loop;
     }

+    // TODO fix this
+    template <typename CGridDesc_M_N_any>
     __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
+    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N_any& c_grid_desc_m_n)
     {
         constexpr auto max_lds_align = K1;
@@ -369,6 +322,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
     using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
         decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));

+    using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
+        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{}));
+
+    using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
+        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{}));
+
     using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));

     template <bool HasMainKBlockLoop>
@@ -376,10 +336,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
+        const FloatC* __restrict__ p_c0_grid,
+        const FloatC* __restrict__ p_c1_grid,
         FloatAB* __restrict__ p_shared_block,
         const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
         const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+        const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+        const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
         const AElementwiseOperation& a_element_op,
         const BElementwiseOperation& b_element_op,
         const CElementwiseOperation& c_element_op,
@@ -392,6 +356,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

+        auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
+
+        auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_c1_grid, c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
+
         const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

         // divide block work by [M, N]
...
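The hunk above also deletes the experimental `CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER` variant of `kernel_gemm_xdlops_v2r5`, which smuggled each descriptor into the kernel as a `const void CONSTANT*` and rebuilt it by value via `reinterpret_cast`. Stripped of the GPU address-space plumbing, that pattern is just passing a trivially copyable struct through an opaque pointer. A minimal standalone sketch (`Desc` and `consume` are hypothetical stand-ins, not CK code):

```cpp
#include <cassert>

// Stand-in for a tensor descriptor: any trivially copyable aggregate.
struct Desc
{
    int k0, m, k1;
};

// The callee recovers the concrete type from the opaque pointer and copies the
// descriptor by value, as the deleted kernel branch did after
// cast_pointer_to_generic_address_space().
int consume(const void* p_desc)
{
    const auto desc = *reinterpret_cast<const Desc*>(p_desc);
    return desc.k0 * desc.m * desc.k1;
}

int main()
{
    Desc desc{2, 3, 4};
    assert(consume(&desc) == 24);
    return 0;
}
```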
@@ -14,7 +14,9 @@
 #include "device_base.hpp"
+#include "example/2_gemm_xdl_bias_add/include/device_gemm_xdl_bias_add.hpp"

-// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * D[m] + gamma * E[m, n]
+// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m]
+// assume C0 has same layout as C
+// assume C1 is contiguous in memory
 struct PassThrough
 {
@@ -175,9 +177,19 @@ int main(int argc, char* argv[])
     Tensor<BDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
     Tensor<BDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

+    // C0[m, n]
+    Tensor<BDataType> c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+    // C1[m]
+    Tensor<CDataType> c1_m_n(HostTensorDescriptor(
+        std::vector<std::size_t>({static_cast<std::size_t>(M), static_cast<std::size_t>(N)}),
+        std::vector<std::size_t>({1, 0})));
+
     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
+    std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl;
+    std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl;

     switch(init_method)
     {
@@ -185,19 +197,27 @@ int main(int argc, char* argv[])
     case 1:
         a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
         b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        c0_m_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
+        c1_m_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
         break;
     default:
         a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        c0_m_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0});
+        c1_m_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0});
     }

     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
     DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
+    DeviceMem c0_m_n_device_buf(sizeof(CDataType) * c0_m_n.mDesc.GetElementSpace());
+    DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace());

     a_m_k_device_buf.ToDevice(a_m_k.mData.data());
     b_k_n_device_buf.ToDevice(b_k_n.mData.data());
     c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
+    c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
+    c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());

     // do GEMM
     auto gemm = typename DeviceGemmInstance<ADataType,
@@ -214,6 +234,8 @@ int main(int argc, char* argv[])
     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                       static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                       static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c0_m_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c1_m_n_device_buf.GetDeviceBuffer()),
                                       M,
                                       N,
                                       K,
...
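Per the comment in the example, the operation being exercised is C[m, n] = alpha * (A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m], with C1 broadcast along n. A naive host-side reference for that formula might look like the sketch below; the scalar type, the coefficient handling, and the packed row-major layout are assumptions (the example's actual verification path is in a collapsed part of this diff):

```cpp
#include <cstddef>
#include <vector>

// Naive reference for C = alpha * (A * B) + beta * C0 + gamma * C1,
// with A: MxK, B: KxN, C0: MxN, and C1 a length-M vector broadcast along n.
// All buffers are densely packed row-major in this sketch.
template <typename T>
void reference_gemm_bias_add(const std::vector<T>& a,
                             const std::vector<T>& b,
                             const std::vector<T>& c0,
                             const std::vector<T>& c1,
                             std::vector<T>& c,
                             std::size_t M, std::size_t N, std::size_t K,
                             T alpha, T beta, T gamma)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            T acc = 0;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];

            // C1[m] is reused for every n: this is the stride-0 broadcast.
            c[m * N + n] = alpha * acc + beta * c0[m * N + n] + gamma * c1[m];
        }
    }
}
```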
#ifndef DEVICE_GEMM_HPP
#define DEVICE_GEMM_HPP

#include <iostream>
#include <memory> // for std::unique_ptr
#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceGemmPtr = std::unique_ptr<
    DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
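`DeviceGemm` is a classic type-erasure seam: callers hold a `DeviceGemmPtr` and drive the kernel through `MakeArgumentPointer` / `MakeInvokerPointer` without ever naming the concrete device op. A self-contained mock of that pattern (all names below are stand-ins, not CK's real classes):

```cpp
#include <iostream>
#include <memory>

// Opaque bases, mirroring the roles of BaseArgument / BaseInvoker in device_base.hpp.
struct BaseArgument
{
    virtual ~BaseArgument() = default;
};

struct BaseInvoker
{
    virtual void Run(const BaseArgument* p_arg) = 0;
    virtual ~BaseInvoker() = default;
};

// One concrete operator; client code only ever sees the base types.
struct MyGemmArgument final : BaseArgument
{
    MyGemmArgument(int M_, int N_, int K_) : M(M_), N(N_), K(K_) {}
    int M, N, K;
};

struct MyGemmInvoker final : BaseInvoker
{
    void Run(const BaseArgument* p_arg) override
    {
        const auto* arg = static_cast<const MyGemmArgument*>(p_arg);
        std::cout << "launch gemm " << arg->M << "x" << arg->N << "x" << arg->K << '\n';
    }
};

int main()
{
    std::unique_ptr<BaseArgument> arg     = std::make_unique<MyGemmArgument>(960, 1024, 1024);
    std::unique_ptr<BaseInvoker>  invoker = std::make_unique<MyGemmInvoker>();
    invoker->Run(arg.get()); // dispatches to the concrete op behind the base pointer
    return 0;
}
```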
@@ -129,6 +129,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
     using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
     using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
     using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
+    using C0GridDesc_M_N    = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
+
+    // hardcoding
+    // TODO: fix this
+    using C1GridDesc_M_N =
+        decltype(make_naive_tensor_descriptor(make_tuple(1, 1), make_tuple(I1, I0)));

     // TODO remove these hacks
     static constexpr auto a_k0_m_k1_grid_step_hacks =
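The `make_tuple(I1, I0)` strides are what turn the length-M bias vector into a logical M x N tensor: with lengths {M, N} and strides {1, 0}, element (m, n) maps to offset m for every n. A tiny standalone model of that index math (CK's descriptors of course do far more):

```cpp
#include <cstddef>

// Offset of element (m, n) in a 2-D view with explicit strides.
constexpr std::size_t offset(std::size_t m, std::size_t n,
                             std::size_t stride_m, std::size_t stride_n)
{
    return m * stride_m + n * stride_n;
}

int main()
{
    // C1 uses strides {1, 0}: every column of row m aliases element m.
    static_assert(offset(3, 0, 1, 0) == 3, "row 3 starts at element 3");
    static_assert(offset(3, 7, 1, 0) == 3, "same element for every n");

    // Contrast with a packed row-major M x N tensor (strides {N, 1}), as C uses.
    constexpr std::size_t N = 8;
    static_assert(offset(3, 7, N, 1) == 31, "packed layout advances with n");
    return 0;
}
```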
@@ -179,6 +185,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
+        C0GridDesc_M_N,
+        C1GridDesc_M_N,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
@@ -221,6 +229,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
     using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
         decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));

+    using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
+        decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{}));
+
+    using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
+        decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{}));
+
     using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));

     // Argument
@@ -229,6 +243,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
                  CDataType* p_c_grid,
+                 const CDataType* p_c0_grid,
+                 const CDataType* p_c1_grid,
                  index_t M,
                  index_t N,
                  index_t K,
@@ -243,10 +259,16 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
+              p_c0_grid_{p_c0_grid},
+              p_c1_grid_{p_c1_grid},
               a_grid_desc_k0_m_k1_{},
               b_grid_desc_k0_n_k1_{},
               c_grid_desc_m_n_{},
+              c0_grid_desc_m_n_{},
+              c1_grid_desc_m_n_{},
               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
+              c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
+              c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
@@ -261,12 +283,27 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
             c_grid_desc_m_n_ =
                 DeviceGemmXdl_two_extra_source_reduce::MakeCGridDescriptor_M_N(M, N, StrideC);

+            // assume C0 has same layout as C
+            // TODO: fix this
+            c0_grid_desc_m_n_ =
+                DeviceGemmXdl_two_extra_source_reduce::MakeCGridDescriptor_M_N(M, N, StrideC);
+
+            // hardcoding C1 layout
+            // TODO: fix this
+            c1_grid_desc_m_n_ = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, I0));
+
             if(GridwiseGemm::CheckValidity(
                    a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
             {
                 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
                     GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);

+                c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
+                    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c0_grid_desc_m_n_);
+
+                c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
+                    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c1_grid_desc_m_n_);
+
                 block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
             }
         }
@@ -275,10 +312,16 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
+        const CDataType* p_c0_grid_;
+        const CDataType* p_c1_grid_;
        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
+        C0GridDesc_M_N c0_grid_desc_m_n_;
+        C1GridDesc_M_N c1_grid_desc_m_n_;
        CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
+        C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
+        C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
        Block2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
@@ -305,6 +348,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
                 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                           << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;

+                std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
+                          << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+
+                std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
+                          << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
             }

             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
@@ -335,6 +384,10 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
                     remove_reference_t<DeviceGemmXdl_two_extra_source_reduce::BGridDesc_K0_N_K1>,
                     remove_reference_t<
                         DeviceGemmXdl_two_extra_source_reduce::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<
+                        DeviceGemmXdl_two_extra_source_reduce::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<
+                        DeviceGemmXdl_two_extra_source_reduce::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
@@ -349,9 +402,13 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
+                    arg.p_c0_grid_,
+                    arg.p_c1_grid_,
                     arg.a_grid_desc_k0_m_k1_,
                     arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                    arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                    arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
@@ -367,6 +424,10 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
                     remove_reference_t<DeviceGemmXdl_two_extra_source_reduce::BGridDesc_K0_N_K1>,
                     remove_reference_t<
                         DeviceGemmXdl_two_extra_source_reduce::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<
+                        DeviceGemmXdl_two_extra_source_reduce::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<
+                        DeviceGemmXdl_two_extra_source_reduce::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
@@ -381,9 +442,13 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
+                    arg.p_c0_grid_,
+                    arg.p_c1_grid_,
                     arg.a_grid_desc_k0_m_k1_,
                     arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                    arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                    arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
@@ -424,6 +489,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
                              CDataType* p_c,
+                             const CDataType* p_c0,
+                             const CDataType* p_c1,
                              index_t M,
                              index_t N,
                              index_t K,
@@ -437,6 +504,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
         return Argument{p_a,
                         p_b,
                         p_c,
+                        p_c0,
+                        p_c1,
                         M,
                         N,
                         K,
@@ -456,6 +525,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                       const void* p_b,
                                                       void* p_c,
+                                                      const void* p_c0,
+                                                      const void* p_c1,
                                                       index_t M,
                                                       index_t N,
                                                       index_t K,
@@ -469,6 +540,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
+                                          static_cast<const CDataType*>(p_c0),
+                                          static_cast<const CDataType*>(p_c1),
                                           M,
                                           N,
                                           K,
...
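The tile-level combine itself lives in collapsed hunks of the gridwise kernel, so it is not visible above; per the formula comment in the example, each output element ends up as alpha * acc + beta * c0 + gamma * c1. A hedged sketch of what such a per-thread epilogue amounts to (names, types, and the coefficient plumbing are illustrative only; the actual CK code routes the combine through `CElementwiseOperation`):

```cpp
#include "hip/hip_runtime.h"

// Illustrative per-thread epilogue: each thread owns `num_elems` accumulator
// values plus the matching C0/C1 elements, already gathered through their
// respective grid descriptors (so the C1 stride-0 broadcast happened earlier).
template <typename AccFloat, typename CFloat>
__device__ void epilogue_bias_add(const AccFloat* acc, // blockwise GEMM results
                                  const CFloat* c0,    // bias with the same layout as C
                                  const CFloat* c1,    // per-row bias, pre-broadcast
                                  CFloat* c,           // output tile
                                  int num_elems,
                                  AccFloat alpha,
                                  AccFloat beta,
                                  AccFloat gamma)
{
    for(int i = 0; i < num_elems; ++i)
    {
        const AccFloat v = alpha * acc[i] + beta * static_cast<AccFloat>(c0[i]) +
                           gamma * static_cast<AccFloat>(c1[i]);
        c[i] = static_cast<CFloat>(v);
    }
}
```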