Commit 80f038a0 authored by root
Browse files

Post-merge fixes

parent bb1f8082
...@@ -17,7 +17,8 @@ using ABDataType = F16; ...@@ -17,7 +17,8 @@ using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32; using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add; using Add = ck::tensor_operation::binary_element_wise::
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
...@@ -48,11 +49,11 @@ void host_broadcast3D_am_bmnk(HostTensorC& C, ...@@ -48,11 +49,11 @@ void host_broadcast3D_am_bmnk(HostTensorC& C,
for(std::size_t n = 0; n < shape[1]; ++n) for(std::size_t n = 0; n < shape[1]; ++n)
for(std::size_t k = 0; k < shape[2]; ++k) for(std::size_t k = 0; k < shape[2]; ++k)
{ {
ComputeDataType a_val = static_cast<ComputeDataType>(A(m)); ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(m));
ComputeDataType b_val = static_cast<ComputeDataType>(B(m, n, k)); ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(m, n, k));
ComputeDataType c_val = 0; ComputeDataType c_val = 0;
functor(c_val, a_val, b_val); functor(c_val, a_val, b_val);
C(m, n, k) = static_cast<ctype>(c_val); C(m, n, k) = ck::type_convert<ctype>(c_val);
} }
} }
......
...@@ -93,39 +93,37 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -93,39 +93,37 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto ScalarPerVector = Number<4>{}; static constexpr auto MPerThread = Number<4>{};
static constexpr auto AScalarPerVector = Number<4>{};
static constexpr auto BScalarPerVector = Number<4>{};
static constexpr auto CScalarPerVector = Number<4>{};
template <typename Desc_M0> template <typename Desc_M>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{ {
const auto m0 = desc_m0.GetLength(I0); const auto M = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * ScalarPerVector; const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0; const auto pad = math::integer_least_multiple(M, loop_step) - M;
const auto desc_m0_pad = const auto desc_m_pad =
transform_tensor_descriptor(desc_m0, transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(m0, pad)), make_tuple(make_right_pad_transform(M, pad)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return desc_m0_pad; return desc_m_pad;
} }
static auto MakeDescriptor_M0(const std::vector<int>& shape, static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<int>& stride, const std::vector<index_t>& strides,
index_t gridSize, index_t gridSize,
index_t blockSize) index_t blockSize)
{ {
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<2>{}); auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<1>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<2>{}); auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<1>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
const auto desc_m0 = transform_tensor_descriptor( return PadDescriptor_M_1d(desc, gridSize, blockSize);
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
} }
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
...@@ -395,7 +393,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -395,7 +393,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
...@@ -492,13 +490,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -492,13 +490,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{ {
c_grid_desc_m0_ = c_grid_desc_m_ =
DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize); DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize);
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{ {
c_grid_desc_m0_ = c_grid_desc_m_ =
DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize); DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize);
} }
p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize(); p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize();
...@@ -516,7 +514,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -516,7 +514,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
GridDesc_M0 c_grid_desc_m0_; CGridDesc_M c_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -556,27 +554,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -556,27 +554,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType, CDataType,
CDataType, CDataType,
CDataType, CDataType,
GridDesc_M0, CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Add, Add,
ScalarPerVector>; MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType, using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType,
CDataType, CDataType,
CDataType, CDataType,
CDataType, CDataType,
GridDesc_M0, CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Substract, Substract,
ScalarPerVector>; MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd, const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
CDataType, CDataType,
CDataType, CDataType,
CDataType, CDataType,
GridDesc_M0, CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Add>; Add>;
const auto substract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubstract, const auto substract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubstract,
CDataType, CDataType,
CDataType, CDataType,
CDataType, CDataType,
GridDesc_M0, CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Substract>; Substract>;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
...@@ -637,9 +649,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -637,9 +649,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.p_aux_grid_, arg.p_aux_grid_,
arg.p_aux_2_grid_, arg.p_aux_2_grid_,
arg.p_c_grid_real_, arg.p_c_grid_real_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
Substract{}); Substract{});
ave_time += ave_time +=
...@@ -685,9 +697,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -685,9 +697,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.p_aux_grid_, arg.p_aux_grid_,
arg.p_aux_2_grid_, arg.p_aux_2_grid_,
arg.p_c_grid_imag_, arg.p_c_grid_imag_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
Add{}); Add{});
} }
else else
...@@ -748,9 +760,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -748,9 +760,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.p_aux_grid_, arg.p_aux_grid_,
arg.p_aux_2_grid_, arg.p_aux_2_grid_,
arg.p_c_grid_real_, arg.p_c_grid_real_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
Substract{}); Substract{});
ave_time += ave_time +=
...@@ -796,9 +808,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -796,9 +808,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.p_aux_grid_, arg.p_aux_grid_,
arg.p_aux_2_grid_, arg.p_aux_2_grid_,
arg.p_c_grid_imag_, arg.p_c_grid_imag_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
arg.c_grid_desc_m0_, arg.c_grid_desc_m_,
Add{}); Add{});
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment