"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "c63906a6cd98f020026e6cd1d35498c46bf74ae3"
Commit 96c73d70 authored by Chao Liu

add missing type convert

parent 2d35fac0
@@ -11,8 +11,7 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
-#include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
 #include "gemm_specialization.hpp"
@@ -37,47 +36,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
 
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
 // clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
-    ADataType,              // ADataType
-    BDataType,              // BDataType
-    CDataType,              // CDataType
-    AccDataType,            // AccDataType
-    CDataType,              // CShuffleDataType
-    ALayout,                // ALayout
-    BLayout,                // BLayout
-    CLayout,                // CLayout
-    PassThrough,            // AElementwiseOperation
-    PassThrough,            // BElementwiseOperation
-    PassThrough,            // CElementwiseOperation
-    256,                    // BlockSize
-    256,                    // MPerBlock
-    128,                    // NPerBlock
-    32,                     // KPerBlock
-    8,                      // AK1
-    8,                      // BK1
-    32,                     // MPerXDL
-    32,                     // NPerXDL
-    4,                      // MXdlPerWave
-    2,                      // NXdlPerWave
-    S<4, 64, 1>,            // ABlockTransferThreadClusterLengths_K0_M_K1
-    S<1, 0, 2>,             // ABlockTransferThreadClusterArrangeOrder
-    S<1, 0, 2>,             // ABlockTransferSrcAccessOrder
-    2,                      // ABlockTransferSrcVectorDim
-    8,                      // ABlockTransferSrcScalarPerVector
-    8,                      // ABlockTransferDstScalarPerVector_K1
-    true,                   // ABlockLdsAddExtraM
-    S<4, 64, 1>,            // BBlockTransferThreadClusterLengths_K0_N_K1
-    S<1, 0, 2>,             // BBlockTransferThreadClusterArrangeOrder
-    S<1, 0, 2>,             // BBlockTransferSrcAccessOrder
-    2,                      // BBlockTransferSrcVectorDim
-    8,                      // BBlockTransferSrcScalarPerVector
-    8,                      // BBlockTransferDstScalarPerVector_K1
-    true,                   // BBlockLdsAddExtraN
-    1,                      // CShuffleMXdlPerWavePerShuffle
-    1,                      // CShuffleNXdlPerWavePerShuffle
-    S<1, 1, 32, 1, 1, 8>,   // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-    8>;                     // CBlockTransferScalarPerVector_NWaveNPerXdl
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
+    <ALayout,               // typename ALayout,
+     BLayout,               // typename BLayout,
+     CLayout,               // typename CLayout,
+     ADataType,             // typename ADataType,
+     BDataType,             // typename BDataType,
+     CDataType,             // typename CDataType,
+     AccDataType,           // typename GemmAccDataType,
+     CDataType,             // typename CShuffleDataType,
+     PassThrough,           // typename AElementwiseOperation,
+     PassThrough,           // typename BElementwiseOperation,
+     PassThrough,           // typename CElementwiseOperation,
+     GemmDefault,           // GemmSpecialization GemmSpec,
+     1,                     // index_t NumGemmKPrefetchStage,
+     256,                   // index_t BlockSize,
+     256,                   // index_t MPerBlock,
+     128,                   // index_t NPerBlock,
+     32,                    // index_t KPerBlock,
+     8,                     // index_t AK1,
+     8,                     // index_t BK1,
+     32,                    // index_t MPerXDL,
+     32,                    // index_t NPerXDL,
+     4,                     // index_t MXdlPerWave,
+     2,                     // index_t NXdlPerWave,
+     S<4, 64, 1>,           // typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
+     S<1, 0, 2>,            // typename ABlockTransferThreadClusterArrangeOrder,
+     S<1, 0, 2>,            // typename ABlockTransferSrcAccessOrder,
+     2,                     // index_t ABlockTransferSrcVectorDim,
+     8,                     // index_t ABlockTransferSrcScalarPerVector,
+     8,                     // index_t ABlockTransferDstScalarPerVector_AK1,
+     1,                     // bool ABlockLdsExtraM,
+     S<4, 64, 1>,           // typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
+     S<1, 0, 2>,            // typename BBlockTransferThreadClusterArrangeOrder,
+     S<1, 0, 2>,            // typename BBlockTransferSrcAccessOrder,
+     2,                     // index_t BBlockTransferSrcVectorDim,
+     8,                     // index_t BBlockTransferSrcScalarPerVector,
+     8,                     // index_t BBlockTransferDstScalarPerVector_BK1,
+     1,                     // bool BBlockLdsExtraN,
+     1,                     // index_t CShuffleMXdlPerWavePerShuffle,
+     1,                     // index_t CShuffleNXdlPerWavePerShuffle,
+     S<1, 32, 1, 8>,        // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+     8>;                    // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
 // clang-format on
 
 using ReferenceGemmInstance = ck::tensor_operation::host::
...
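For reference, CK's GEMM examples drive an instance like this through the device op's argument/invoker interface. The snippet below is an illustrative sketch, not part of this commit: the buffer names are assumed, and the exact MakeArgument argument order and Run() parameters vary across CK revisions.

// Illustrative driver sketch (assumed names: a_device_buf, b_device_buf, c_device_buf,
// M, N, K, StrideA, StrideB, StrideC set up as in the surrounding example).
auto gemm     = DeviceGemmInstance{};
auto invoker  = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                  static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                  static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                  M, N, K, StrideA, StrideB, StrideC,
                                  PassThrough{}, PassThrough{}, PassThrough{});

if(!gemm.IsSupportedArgument(argument))
{
    throw std::runtime_error("this DeviceGemm instance does not support the given problem");
}

invoker.Run(argument); // timing/repeat arguments differ between CK revisions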
@@ -46,7 +46,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpeciali
 static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
 
 // clang-format off
-#if 0
+#if 1
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
 //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
...
@@ -11,8 +11,7 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
-#include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
 #include "gemm_specialization.hpp"
@@ -20,64 +19,63 @@
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
-using F32 = float;
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
 using ADataType = int8_t;
 using BDataType = int8_t;
-using CDataType = int32_t;
+using CDataType = int8_t;
 using AccDataType = int32_t;
-using CShuffleDataType = int32_t;
+using CShuffleDataType = int8_t;
 
 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
 
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
 // clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
-    ADataType,              // ADataType
-    BDataType,              // BDataType
-    CDataType,              // CDataType
-    AccDataType,            // AccDataType
-    CShuffleDataType,       // CShuffleDataType
-    ALayout,                // ALayout
-    BLayout,                // BLayout
-    CLayout,                // CLayout
-    PassThrough,            // AElementwiseOperation
-    PassThrough,            // BElementwiseOperation
-    PassThrough,            // CElementwiseOperation
-    256,                    // BlockSize
-    256,                    // MPerBlock
-    128,                    // NPerBlock
-    64,                     // KPerBlock
-    16,                     // AK1
-    16,                     // BK1
-    32,                     // MPerXDL
-    32,                     // NPerXDL
-    4,                      // MXdlPerWave
-    2,                      // NXdlPerWave
-    S<4, 64, 1>,            // ABlockTransferThreadClusterLengths_K0_M_K1
-    S<1, 0, 2>,             // ABlockTransferThreadClusterArrangeOrder
-    S<1, 0, 2>,             // ABlockTransferSrcAccessOrder
-    2,                      // ABlockTransferSrcVectorDim
-    16,                     // ABlockTransferSrcScalarPerVector
-    16,                     // ABlockTransferDstScalarPerVector_K1
-    true,                   // ABlockLdsAddExtraM
-    S<4, 64, 1>,            // BBlockTransferThreadClusterLengths_K0_N_K1
-    S<1, 0, 2>,             // BBlockTransferThreadClusterArrangeOrder
-    S<1, 0, 2>,             // BBlockTransferSrcAccessOrder
-    2,                      // BBlockTransferSrcVectorDim
-    16,                     // BBlockTransferSrcScalarPerVector
-    16,                     // BBlockTransferDstScalarPerVector_K1
-    true,                   // BBlockLdsAddExtraN
-    1,                      // CShuffleMXdlPerWavePerShuffle
-    1,                      // CShuffleNXdlPerWavePerShuffle
-    S<1, 1, 32, 1, 1, 8>,   // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-    4>;                     // CBlockTransferScalarPerVector_NWaveNPerXdl
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
+    ALayout,                // typename ALayout,
+    BLayout,                // typename BLayout,
+    CLayout,                // typename CLayout,
+    ADataType,              // typename ADataType,
+    BDataType,              // typename BDataType,
+    CDataType,              // typename CDataType,
+    AccDataType,            // typename GemmAccDataType,
+    CShuffleDataType,       // typename CShuffleDataType,
+    PassThrough,            // typename AElementwiseOperation,
+    PassThrough,            // typename BElementwiseOperation,
+    PassThrough,            // typename CElementwiseOperation,
+    GemmDefault,            // GemmSpecialization GemmSpec,
+    1,                      // index_t NumGemmKPrefetchStage,
+    256,                    // index_t BlockSize,
+    256,                    // index_t MPerBlock,
+    128,                    // index_t NPerBlock,
+    64,                     // index_t KPerBlock,
+    16,                     // index_t AK1,
+    16,                     // index_t BK1,
+    32,                     // index_t MPerXDL,
+    32,                     // index_t NPerXDL,
+    4,                      // index_t MXdlPerWave,
+    2,                      // index_t NXdlPerWave,
+    S<4, 64, 1>,            // typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
+    S<1, 0, 2>,             // typename ABlockTransferThreadClusterArrangeOrder,
+    S<1, 0, 2>,             // typename ABlockTransferSrcAccessOrder,
+    2,                      // index_t ABlockTransferSrcVectorDim,
+    16,                     // index_t ABlockTransferSrcScalarPerVector,
+    16,                     // index_t ABlockTransferDstScalarPerVector_AK1,
+    1,                      // bool ABlockLdsExtraM,
+    S<4, 64, 1>,            // typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
+    S<1, 0, 2>,             // typename BBlockTransferThreadClusterArrangeOrder,
+    S<1, 0, 2>,             // typename BBlockTransferSrcAccessOrder,
+    2,                      // index_t BBlockTransferSrcVectorDim,
+    8,                      // index_t BBlockTransferSrcScalarPerVector,
+    8,                      // index_t BBlockTransferDstScalarPerVector_BK1,
+    1,                      // bool BBlockLdsExtraN,
+    1,                      // index_t CShuffleMXdlPerWavePerShuffle,
+    1,                      // index_t CShuffleNXdlPerWavePerShuffle,
+    S<1, 64, 1, 4>,         // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+    16>;                    // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
 // clang-format on
 
 using ReferenceGemmInstance = ck::tensor_operation::host::
...
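In this int8 example, CDataType is now int8_t while AccDataType remains int32_t, so the epilogue has to narrow the 32-bit accumulator into the 8-bit output; that is the conversion the threadwise-transfer hunks below add. A standalone host-side sketch of that contract follows (illustrative names, no CK types; real int8 paths usually also scale or saturate before narrowing):

#include <cstdint>
#include <vector>

// Reference GEMM with int8 inputs, int32 accumulation, int8 output.
// The final cast plays the role of the explicit type_convert in the kernel epilogue.
void gemm_int8_reference(const std::vector<int8_t>& a, // M x K, row-major
                         const std::vector<int8_t>& b, // K x N, column-major
                         std::vector<int8_t>& c,       // M x N, row-major
                         int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            int32_t acc = 0; // accumulate in the wider type (AccDataType)
            for(int k = 0; k < K; ++k)
            {
                acc += int32_t(a[m * K + k]) * int32_t(b[n * K + k]);
            }
            c[m * N + n] = static_cast<int8_t>(acc); // narrow to CDataType
        }
    }
}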
@@ -51,7 +51,7 @@ template <typename SrcData,
           typename DstData,
           typename SrcDesc,
           typename DstDesc,
-          typename DstElementwiseOperation,
+          typename ElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,
@@ -70,12 +70,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
 
-    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
-        const DstDesc& dst_desc,
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
                                                             const Index& dst_slice_origin_idx,
-                                                            const DstElementwiseOperation& dst_element_op)
+                                                            const ElementwiseOperation& element_op)
         : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
-          dst_element_op_{dst_element_op}
+          element_op_{element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");
@@ -136,13 +135,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
 
-                SrcData dst_v;
+                SrcData v;
 
                 // apply element-wise operation
-                dst_element_op_(dst_v, src_buf[Number<src_offset>{}]);
+                element_op_(v, src_buf[Number<src_offset>{}]);
 
                 // apply type convert
-                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(dst_v);
+                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
             });
 
             const bool is_dst_valid =
@@ -213,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     private:
     DstCoord dst_coord_;
-    const DstElementwiseOperation dst_element_op_;
+    const ElementwiseOperation element_op_;
 }; // namespace ThreadwiseTensorSliceTransfer_v1r3
 
 // Assume:
...
@@ -102,8 +102,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
             // apply pointwise operation
             static_for<0, ScalarPerVector, 1>{}([&](auto i) {
-                element_op_(dst_vector_container.template AsType<DstData>()(i),
-                            src_vector_container.template AsType<SrcData>()[i]);
+                SrcData v;
+
+                // apply element-wise operation
+                element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
+
+                // apply type convert
+                dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
             });
 
             const bool is_dst_valid =
...
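The v6r1 change mirrors the v1r3 path above: the element-wise operation now writes into a temporary of the source type, and an explicit type_convert narrows the result to the destination type, instead of handing the functor a DstData reference directly. Below is a standalone sketch of that pattern with a PassThrough-style functor (illustrative names, not the CK implementation):

#include <cstdint>

// PassThrough-style element-wise op: copies its input to its output unchanged.
struct PassThroughLike
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const { y = x; }
};

// Apply the op in SrcData, then convert explicitly to DstData, as in the fix.
template <typename DstData, typename SrcData, typename ElementwiseOp>
void transfer_with_convert(DstData* dst, const SrcData* src, int n, ElementwiseOp op)
{
    for(int i = 0; i < n; ++i)
    {
        SrcData v;                        // keep full precision while applying the op
        op(v, src[i]);                    // element-wise operation in the source type
        dst[i] = static_cast<DstData>(v); // explicit narrowing to the destination type
    }
}

int main()
{
    const int32_t acc[4] = {1, -2, 3, -4};
    int8_t out[4];
    transfer_with_convert(out, acc, 4, PassThroughLike{});
    return 0;
}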