Unverified Commit 4173b984 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-756

parents 6de7d10d 85e2e1e2
......@@ -36,6 +36,7 @@ using F64_Tuple = ck::Tuple<F64>;
using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>;
using I8_Tuple = ck::Tuple<I8>;
using F32_F32_Tuple = ck::Tuple<F32, F32>;
......
......@@ -23,7 +23,7 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
......@@ -38,7 +38,7 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
......@@ -53,7 +53,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
......@@ -68,7 +68,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
......@@ -374,7 +374,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
......@@ -385,7 +385,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
......@@ -397,7 +397,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
......@@ -408,7 +408,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
......
......@@ -53,6 +53,11 @@ using device_grouped_conv2d_fwd_dl_f16_instances = std::tuple<
// ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instances
// TODO: Change to ScalarPerVector = 1 when inner_product<half_t, half_t, float> will be supported
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 2, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 1, 2>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
>;
......@@ -71,6 +76,10 @@ using device_grouped_conv2d_fwd_dl_f32_instances = std::tuple<
// ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instances
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F32, F32, DsDatatype, F32, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 1, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 1>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F32, F32, DsDatatype, F32, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F32, F32, DsDatatype, F32, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
>;
......
......@@ -95,6 +95,22 @@ struct GeneratorTensor_2<int8_t>
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_2<ck::f8_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::f8_t operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return ck::type_convert<ck::f8_t>(tmp);
}
};
#endif
template <typename T>
struct GeneratorTensor_3
{
......@@ -127,6 +143,25 @@ struct GeneratorTensor_3<ck::bhalf_t>
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_3<ck::f8_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::f8_t operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
float fp32_tmp = min_value + tmp * (max_value - min_value);
return ck::type_convert<ck::f8_t>(fp32_tmp);
}
};
#endif
template <typename T>
struct GeneratorTensor_4
{
......
......@@ -31,10 +31,10 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp)
endif()
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment