Commit 208ac1a5 authored by myamlak
Browse files

Consuming binary ops to do A+B / A-B

parent 5e104742
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp" #include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "binary_element_wise_operation.hpp"
#include "gridwise_binary_elementwise_1d.hpp"
#include "tensor_operation/gpu/device/gemm_specialization.hpp" #include "tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck { namespace ck {
...@@ -66,6 +68,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -66,6 +68,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto ScalarPerVector = Number<4>{};
template <typename Desc_M0>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t threadPerBlock)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
}
// Builds a padded 1-D descriptor from a 2-D shape/stride pair.
//
// Only entries 0 and 1 of `shape` and `stride` are read. A naive 2-D descriptor is
// created from them, both dimensions are merged into a single dimension, and the
// result is passed through PadDescriptor_M0_1d with the given launch parameters.
//
// shape          : lengths of the two dimensions (at least 2 entries are read)
// stride         : strides of the two dimensions (at least 2 entries are read)
// gridSize       : forwarded to PadDescriptor_M0_1d
// threadPerBlock : forwarded to PadDescriptor_M0_1d
static auto MakeDescriptor_M0(const std::vector<int>& shape,
                              const std::vector<int>& stride,
                              index_t gridSize,
                              index_t threadPerBlock)
{
    // Lift the first two runtime entries of each vector into tuples.
    auto strides = generate_tuple([&](auto I) { return stride[I]; }, Number<2>{});
    auto lengths = generate_tuple([&](auto I) { return shape[I]; }, Number<2>{});

    const auto desc_2d = make_naive_tensor_descriptor(lengths, strides);

    // Merge dimensions {0, 1} of the 2-D descriptor into the single dimension 0.
    const auto merged_1d = transform_tensor_descriptor(
        desc_2d,
        make_tuple(make_merge_transform(lengths)),
        make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})),
        make_tuple(Sequence<0>{}));

    return PadDescriptor_M0_1d(merged_1d, gridSize, threadPerBlock);
}
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
...@@ -333,6 +370,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -333,6 +370,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
...@@ -426,6 +464,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -426,6 +464,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
} }
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
c_grid_desc_m0_ =
DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize);
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
c_grid_desc_m0_ =
DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize);
}
} }
// private: // private:
...@@ -440,6 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -440,6 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
GridDesc_M0 c_grid_desc_m0_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -468,6 +520,35 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -468,6 +520,35 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
using Add = ck::tensor_operation::binary_element_wise::Add;
using Substract = ck::tensor_operation::binary_element_wise::Substract;
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
CDataType,
CDataType,
CDataType,
GridDesc_M0,
Add,
ScalarPerVector>;
using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType,
CDataType,
CDataType,
CDataType,
GridDesc_M0,
Substract,
ScalarPerVector>;
const auto add_kernel = kernel_elementwise_1d<GridwiseBinAdd,
CDataType,
CDataType,
CDataType,
GridDesc_M0,
Add>;
const auto substract_kernel = kernel_elementwise_1d<GridwiseBinSubstract,
CDataType,
CDataType,
CDataType,
GridDesc_M0,
Substract>;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1< const auto kernel = kernel_gemm_xdl_cshuffle_v1<
...@@ -517,7 +598,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -517,7 +598,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_real = aux - aux_2 needed here!!! // c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config,
substract_kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.p_c_grid_real_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
Substract{});
ave_time += ave_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -553,7 +646,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -553,7 +646,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_imag = aux + aux_2 needed here!!! // c_imag = aux + aux_2
ave_time += launch_and_time_kernel(stream_config,
add_kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.p_c_grid_imag_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
Add{});
} }
else else
{ {
...@@ -604,7 +709,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -604,7 +709,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// // c_real = aux - aux_2 needed here!!! // c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config,
substract_kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.p_c_grid_real_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
Substract{});
ave_time += ave_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -640,7 +757,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -640,7 +757,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_imag = aux + aux_2 needed here!!! // c_imag = aux + aux_2
ave_time += launch_and_time_kernel(stream_config,
add_kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.p_c_grid_imag_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
Add{});
} }
return ave_time; return ave_time;
......
...@@ -12,6 +12,39 @@ struct Add ...@@ -12,6 +12,39 @@ struct Add
{ {
dst = src1 + src2; dst = src1 + src2;
} }
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
{
dst = src1 + src2;
}
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
{
dst = src1 + src2;
}
};
// Binary element-wise functor that writes `src1 - src2` into `dst`.
// Overloads are provided for float, half_t and bhalf_t operands; every overload is
// callable from host and device code and leaves the functor itself unmodified.
struct Substract
{
// float overload: dst = src1 - src2.
__host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const
{
dst = src1 - src2;
}
// half_t overload: dst = src1 - src2 using half_t's own operator-.
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
{
dst = src1 - src2;
}
// bhalf_t overload: dst = src1 - src2 using whatever operator- applies to bhalf_t.
// NOTE(review): if bhalf_t is an alias of an integer storage type, this subtracts the
// raw bit patterns rather than the bfloat16 values; confirm against the bhalf_t
// definition and convert through float if so.
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
{
dst = src1 - src2;
}
}; };
} // namespace binary_element_wise } // namespace binary_element_wise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment