Commit 2fd6c6d4 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents c32d3448 6651a124
@@ -22,14 +22,19 @@ namespace wrapper {
 // Disable from doxygen docs generation
 /// @cond
 // forward declaration
-template <typename Shape, typename UnnestedDescriptorType>
+template <typename Shape, typename UnrolledDescriptorType>
 struct Layout;

 template <typename T>
 using is_tuple = decltype(std::declval<T&>().IsTuple());

 namespace {

-// Generate packed (column-major) strides if not passed
+/**
+ * \brief Generate packed (column-major) strides if not passed.
+ *
+ * \param shape Tensor shape.
+ * \return Generated column-major strides.
+ */
 template <typename... Ts>
 __host__ __device__ constexpr static auto
 GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
@@ -50,8 +55,15 @@ GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
         Number<decltype(unrolled_shape)::Size()>{});
 }

+/**
+ * \brief Create a naive tensor descriptor from a nested shape.
+ *
+ * \param shape Tensor shape.
+ * \param strides Tensor strides.
+ * \return Unrolled descriptor.
+ */
 template <typename LayoutShape, typename LayoutStrides>
-__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape,
-                                                         const LayoutStrides& strides)
+__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
+                                                          const LayoutStrides& strides)
 {
     const auto unrolled_shape = UnrollNestedTuple(shape);
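
For intuition, here is a minimal standalone sketch (plain C++17, no CK dependencies; all names are illustrative, not from this diff) of the packed column-major stride rule documented above, where stride[i] is the product of all extents before dimension i:

#include <array>
#include <cstddef>

// Illustrative re-statement of the packed (column-major) stride rule:
// stride[i] = shape[0] * shape[1] * ... * shape[i-1], with stride[0] = 1.
template <std::size_t N>
constexpr std::array<std::size_t, N>
column_major_packed_strides(const std::array<std::size_t, N>& shape)
{
    std::array<std::size_t, N> strides{};
    std::size_t running = 1;
    for(std::size_t i = 0; i < N; ++i)
    {
        strides[i] = running; // product of all earlier extents
        running *= shape[i];
    }
    return strides;
}

// For shape (2, 3, 4) the packed strides are (1, 2, 6): element
// (i0, i1, i2) lives at offset i0 * 1 + i1 * 2 + i2 * 6.
static_assert(column_major_packed_strides<3>({2, 3, 4})[0] == 1);
static_assert(column_major_packed_strides<3>({2, 3, 4})[1] == 2);
static_assert(column_major_packed_strides<3>({2, 3, 4})[2] == 6);
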
@@ -86,8 +98,8 @@ __host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shap
 template <typename Shape, typename Strides>
 __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
 {
-    using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{}));
-    return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, strides));
+    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
+    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
 }

 /**
@@ -100,15 +112,19 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
 template <typename Shape>
 __host__ __device__ constexpr auto make_layout(const Shape& shape)
 {
-    using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{}));
-    return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, Tuple<>{}));
+    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
+    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
 }

 // Layout helpers
 // get

-// Get dim (could be returned from get with empty Idxs)
 /**
  * \private
+ * \brief Get dim.
+ *
+ * \param dim Dimension.
+ * \return Returns the same dimension.
  */
 template <typename T>
 __host__ __device__ T constexpr get(const T& dim)
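
A hedged usage sketch of the two make_layout overloads above (the shape values, the make_tuple/Number helpers, and the ck::wrapper qualification are assumptions based on CK's utilities, not taken from this diff):

// Sketch only: a 4x8 column-major layout built through the wrapper API.
const auto shape   = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{});
const auto strides = ck::make_tuple(ck::Number<1>{}, ck::Number<4>{});

// Overload with explicit strides:
const auto layout_explicit = ck::wrapper::make_layout(shape, strides);

// Overload without strides: packed column-major strides are generated
// internally via GenerateColumnMajorPackedStrides, so for this shape the
// two layouts are equivalent.
const auto layout_packed = ck::wrapper::make_layout(shape);
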
@@ -178,7 +194,7 @@ __host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
         },
         Number<old_shape_dims>{});

-    const auto& flatten_desc = layout.GetUnnestedDescriptor();
+    const auto& flatten_desc = layout.GetUnrolledDescriptor();
     auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
     return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
 }
@@ -197,9 +213,12 @@ __host__ __device__ constexpr auto get(const T& elem)
 }

 // size

-// Get dim size (could be returned from get function)
 /**
  * \private
+ * \brief Get size.
+ *
+ * \param dim Size.
+ * \return Returns the same size.
  */
 template <typename T>
 __host__ __device__ T constexpr size(const T& dim)
@@ -214,8 +233,8 @@ __host__ __device__ T constexpr size(const T& dim)
  * \param layout Layout to get Shape of.
  * \return Requested length.
  */
-template <index_t idx, typename Shape, typename UnnestedDescriptorType>
-__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
+template <index_t idx, typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     return layout.template GetLength<idx>();
 }
@@ -240,8 +259,8 @@ __host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
  * \param layout Layout to calculate shape size.
  * \return Requested size.
  */
-template <typename Shape, typename UnnestedDescriptorType>
-__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
+template <typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     return layout.GetLengths();
 }
@@ -280,9 +299,9 @@ __host__ __device__ constexpr auto size(const T& elem)
  * \param layout Layout to calculate rank.
  * \return Requested rank.
  */
-template <typename Shape, typename UnnestedDescriptorType>
+template <typename Shape, typename UnrolledDescriptorType>
 __host__ __device__ constexpr auto
-rank([[maybe_unused]] const Layout<Shape, UnnestedDescriptorType>& layout)
+rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     return Shape::Size();
 }
@@ -302,17 +321,25 @@ __host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& t
 /**
  * \private
+ * \brief Rank for a scalar.
+ *
+ * \param dim Dimension scalar.
+ * \return Returns 1.
  */
 template <index_t IDim>
-__host__ __device__ constexpr index_t rank(const Number<IDim>&)
+__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
 {
     return 1;
 }

 /**
  * \private
+ * \brief Rank for a scalar.
+ *
+ * \param dim Dimension scalar.
+ * \return Returns 1.
  */
-__host__ __device__ constexpr index_t rank(const index_t&) { return 1; }
+__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }

 /**
  * \brief Hierarchical rank.
@@ -334,8 +361,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
  * \param layout Layout to calculate depth.
  * \return Requested depth.
  */
-template <typename Shape, typename UnnestedDescriptorType>
-__host__ __device__ constexpr auto depth(const Layout<Shape, UnnestedDescriptorType>& layout)
+template <typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     const auto& shape = layout.GetShape();
     return TupleDepth(shape);
@@ -355,17 +382,25 @@ __host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
 /**
  * \private
+ * \brief Depth for a scalar.
+ *
+ * \param dim Scalar.
+ * \return Returns 0.
  */
 template <index_t IDim>
-__host__ __device__ constexpr index_t depth(const Number<IDim>&)
+__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
 {
     return 0;
 }

 /**
  * \private
+ * \brief Depth for a scalar.
+ *
+ * \param dim Scalar.
+ * \return Returns 0.
  */
-__host__ __device__ constexpr index_t depth(const index_t&) { return 0; }
+__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }

 /**
  * \brief Hierarchical depth.
...
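
Putting the wrapper helpers above together, a hedged summary of their semantics on a nested shape (the values follow from the doc comments in this diff and are not verified against the full source):

// Sketch only: shape = ((2, 3), 4), i.e. a rank-2 tuple whose first
// element is itself a tuple.
// rank(shape)  -> 2   top-level tuple size
// depth(shape) -> 2   nesting levels; depth(4) -> 0 for a scalar
// rank(4)      -> 1   the scalar overloads above return 1
// size(shape)  -> 24  product of all leaf extents (2 * 3 * 4)
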
@@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
             return 0;
         }

+        throw std::runtime_error("Col2Img: number of dimensions should be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
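
This trailing-throw fix appears in ReferenceColumnToImage above and recurs in the reference convolution and image-to-column operators below: after the if-chain over the supported spatial dimensions (1 to 3), an unconditional throw plus an unreachable return replaces a silent fall-through off the end of a non-void function. A minimal sketch of the pattern, with a hypothetical function name:

#include <stdexcept>

// Hypothetical illustration of the pattern added in these hunks.
int CheckNumSpatialDims(int num_dims)
{
    if(num_dims >= 1 && num_dims <= 3)
    {
        return 0; // supported case, as in the reference operators
    }
    // Fail loudly instead of falling off the end of the function,
    // which would be undefined behavior in a non-void function.
    throw std::runtime_error("number of dimensions should be between 1 and 3.");
    return 1; // unreachable; keeps -Wreturn-type style warnings quiet
}
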
@@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
             return 0;
         }

+        throw std::runtime_error(
+            "Conv_bwd_data: number of dimensions must be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
@@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
             return 0;
         }

+        throw std::runtime_error("Conv_bwd: number of dimensions must be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
@@ -360,6 +360,8 @@ struct ReferenceConvFwd : public device::BaseOperator
             return 0;
         }

+        throw std::runtime_error("Conv_fwd: number of dimensions must be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
@@ -63,12 +63,11 @@ struct ReferenceGemm : public device::BaseOperator
                 const int K = arg.a_m_k_.mDesc.GetLengths()[1];

                 AccDataType v_acc = 0;
-                ComputeTypeA v_a  = 0;
-                ComputeTypeB v_b  = 0;

                 for(int k = 0; k < K; ++k)
                 {
+                    ComputeTypeA v_a;
+                    ComputeTypeB v_b;
                     // use PassThrough instead of ConvertBF16RTN for reference calculation
                     if constexpr(is_same_v<AElementwiseOperation,
                                            ck::tensor_operation::element_wise::ConvertBF16RTN>)
@@ -94,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
                         ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                 }

-                CDataType v_c;
+                CDataType v_c = 0;
                 arg.c_element_op_(v_c, v_acc);
...
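
Taken together, the two ReferenceGemm hunks above narrow the operand scope to a single loop iteration and zero-initialize the output before the elementwise functor writes it, so no variable is ever read uninitialized. A condensed sketch of the resulting inner loop; the a/b element-op calls and the m, n, K, arg variables are assumptions from the enclosing Run method, abbreviated here:

AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
    ComputeTypeA v_a; // per-iteration scope: assigned by the element-wise
    ComputeTypeB v_b; // ops below before first use
    arg.a_element_op_(v_a, arg.a_m_k_(m, k));
    arg.b_element_op_(v_b, arg.b_k_n_(k, n));
    v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c = 0; // zero-initialized so the output is never read as garbage
arg.c_element_op_(v_c, v_acc);
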
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -10,6 +10,7 @@
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/numeric.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -229,6 +230,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
             return 0;
         }

+        throw std::runtime_error("Img2Col: number of dimensions should be between 1 and 3.");
+        return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
@@ -106,9 +106,8 @@ struct DeviceOperationInstanceFactory<
         return op_ptrs;
     }
 };
-
+#endif
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
@@ -114,9 +114,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
         return op_ptrs;
     }
 };
-
+#endif
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
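
Both DeviceOperationInstanceFactory hunks above move the #endif so the conditional-compilation guard closes right after the struct it protects rather than at the end of the file; the namespace closings are then always compiled. Schematically (the guard macro and its opening #if sit outside these hunks and are assumed here):

#if defined(SOME_FEATURE_GUARD) // assumed; opened earlier in the file
struct DeviceOperationInstanceFactory</* ... */>
{
    // ...
};
#endif // now closes the guard here, before the namespaces end

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
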
@@ -7,6 +7,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple<
         DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
-        DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+        DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+        DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
 #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
         // pipeline v1, 2 waves
         ,
...
@@ -7,6 +7,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
@@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple<
         DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
-        DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>
+        DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+        DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
 #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
         // pipeline v1, 2 waves
         ,
...
@@ -7,6 +7,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
-       DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+       DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+       DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
 #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
        // pipeline v1, 2 waves
        ,
...