Commit 80f038a0 authored by root's avatar root
Browse files

Post-merge fixes

parent bb1f8082
......@@ -17,7 +17,8 @@ using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using Add = ck::tensor_operation::binary_element_wise::
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
......@@ -48,11 +49,11 @@ void host_broadcast3D_am_bmnk(HostTensorC& C,
for(std::size_t n = 0; n < shape[1]; ++n)
for(std::size_t k = 0; k < shape[2]; ++k)
{
ComputeDataType a_val = static_cast<ComputeDataType>(A(m));
ComputeDataType b_val = static_cast<ComputeDataType>(B(m, n, k));
ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(m));
ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(m, n, k));
ComputeDataType c_val = 0;
functor(c_val, a_val, b_val);
C(m, n, k) = static_cast<ctype>(c_val);
C(m, n, k) = ck::type_convert<ctype>(c_val);
}
}
......
......@@ -93,39 +93,37 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto ScalarPerVector = Number<4>{};
static constexpr auto MPerThread = Number<4>{};
static constexpr auto AScalarPerVector = Number<4>{};
static constexpr auto BScalarPerVector = Number<4>{};
static constexpr auto CScalarPerVector = Number<4>{};
template <typename Desc_M0>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
const auto M = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(M, loop_step) - M;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(M, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
return desc_m_pad;
}
static auto MakeDescriptor_M0(const std::vector<int>& shape,
const std::vector<int>& stride,
index_t gridSize,
index_t blockSize)
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<2>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<2>{});
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<1>{});
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<1>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
const auto desc_m0 = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
return PadDescriptor_M_1d(desc, gridSize, blockSize);
}
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
......@@ -395,7 +393,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
......@@ -492,13 +490,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
c_grid_desc_m0_ =
DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize);
c_grid_desc_m_ =
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize);
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
c_grid_desc_m0_ =
DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize);
c_grid_desc_m_ =
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize);
}
p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize();
......@@ -516,7 +514,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
GridDesc_M0 c_grid_desc_m0_;
CGridDesc_M c_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -556,27 +554,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType,
CDataType,
CDataType,
GridDesc_M0,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Add,
ScalarPerVector>;
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType,
CDataType,
CDataType,
CDataType,
GridDesc_M0,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Substract,
ScalarPerVector>;
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
CDataType,
CDataType,
CDataType,
GridDesc_M0,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Add>;
const auto substract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubstract,
CDataType,
CDataType,
CDataType,
GridDesc_M0,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Substract>;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
......@@ -637,9 +649,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.p_c_grid_real_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
Substract{});
ave_time +=
......@@ -685,9 +697,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.p_c_grid_imag_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
Add{});
}
else
......@@ -748,9 +760,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.p_c_grid_real_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
Substract{});
ave_time +=
......@@ -796,9 +808,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.p_c_grid_imag_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
Add{});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment