Commit 72c9f129 authored by Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 241c261f ded0d83d
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck { namespace ck {
...@@ -99,6 +99,88 @@ struct DeviceOperationInstanceFactory< ...@@ -99,6 +99,88 @@ struct DeviceOperationInstanceFactory<
} }
}; };
// Output (CDE) element-wise operation shared by the declaration and the factory
// specialization in this file; alias for element_wise::ScaleScaleRelu.
using CombConvScaleRelu = ck::tensor_operation::element_wise::ScaleScaleRelu;
#ifdef CK_ENABLE_FP8
// Declaration only; the definition is not in this header.
// Takes the instance vector by non-const reference (presumably appending to it - confirm
// against the definition). Template arguments of the element type, in order:
//   3 spatial dims, NDHWGC input layout, GKZYXC weight layout, no D layouts,
//   NDHWGK output layout, F8 input, F8 weight, no D data types, F32 output,
//   PassThrough / PassThrough / CombConvScaleRelu element-wise ops,
//   F8 A-compute type, F8 B-compute type.
void add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
F8,
F8,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
CombConvScaleRelu,
F8,
F8>>>& instances);
#endif
/// Instance factory for grouped forward convolutions whose output element-wise
/// operation is CombConvScaleRelu (input and weight ops are PassThrough).
///
/// GetInstances() returns a vector of device-op pointers. The vector is only
/// populated when the template arguments describe the one supported problem:
/// 3 spatial dims, NDHWGC / GKZYXC / NDHWGK layouts, f8 input and weight,
/// F32 output and f8 compute types, and only when CK_ENABLE_FP8 is defined.
/// For every other combination the returned vector is empty.
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DLayouts,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DDataTypes,
          typename OutDataType,
          typename AComputeType,
          typename BComputeType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                                                  InLayout,
                                                                  WeiLayout,
                                                                  DLayouts,
                                                                  OutLayout,
                                                                  InDataType,
                                                                  WeiDataType,
                                                                  DDataTypes,
                                                                  OutDataType,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  CombConvScaleRelu,
                                                                  AComputeType,
                                                                  BComputeType>>
{
    /// Device operation type this factory produces instances of.
    using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                                     InLayout,
                                                     WeiLayout,
                                                     DLayouts,
                                                     OutLayout,
                                                     InDataType,
                                                     WeiDataType,
                                                     DDataTypes,
                                                     OutDataType,
                                                     PassThrough,
                                                     PassThrough,
                                                     CombConvScaleRelu,
                                                     AComputeType,
                                                     BComputeType>;

    /// Collects every registered instance matching the template arguments.
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> instances;

        // Layout/dimension gate: only the 3D NDHWGC/GKZYXC/NDHWGK case is handled.
        constexpr bool is_conv3d_ndhwgc_gkzyxc_ndhwgk =
            NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
            is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>;

        if constexpr(is_conv3d_ndhwgc_gkzyxc_ndhwgk)
        {
#ifdef CK_ENABLE_FP8
            // Data-type gate: f8 in/weight, F32 out, f8 compute for both A and B.
            // Kept under `if constexpr` so the call below is never instantiated
            // for a DeviceOp type the add_* function does not accept.
            constexpr bool is_f8_f8_f32 =
                is_same_v<InDataType, f8_t> && is_same_v<WeiDataType, f8_t> &&
                is_same_v<OutDataType, F32> && is_same_v<AComputeType, f8_t> &&
                is_same_v<BComputeType, f8_t>;

            if constexpr(is_f8_f8_f32)
            {
                add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
                    instances);
            }
#endif
        }

        return instances;
    }
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Declarations of the "merged groups" grouped forward-convolution instance
// registration functions. Every function below takes a vector of
// DeviceGroupedConvFwdMultipleABD pointers by non-const reference; definitions
// are not in this file. Each declaration is compiled only when the matching
// CK_ENABLE_<type> macro is defined.
// NOTE(review): this file includes nothing itself, so std::vector, the layout
// tags and the data-type aliases are presumably provided by the including
// header - confirm.
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
// 2D, BF16 input / weight / output, no D tensors, all element-wise ops PassThrough.
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
// 2D, F16 input / weight / output, no D tensors, all element-wise ops PassThrough.
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
// 2D, F32 input / weight / output, no D tensors, all element-wise ops PassThrough.
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
#ifdef CK_ENABLE_BF16
// 3D, BF16 input / weight / output, no D tensors, all element-wise ops PassThrough.
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
// 3D, F16 input / weight / output, no D tensors, all element-wise ops PassThrough.
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
// 3D, F32 input / weight / output, no D tensors, all element-wise ops PassThrough.
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -70,6 +70,12 @@ void add_device_permute_scale_6d_f32_instances( ...@@ -70,6 +70,12 @@ void add_device_permute_scale_6d_f32_instances(
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 6>>>&); DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 6>>>&);
#endif #endif
#ifdef CK_ENABLE_FP8
// Declaration only: 6-dimensional permute-scale instances whose first tuple is F32 and
// second tuple is F8, with element_wise::Scale as the element-wise operation.
// Compiled only when CK_ENABLE_FP8 is defined.
void add_device_permute_scale_6d_f32_f8_instances(
std::vector<std::unique_ptr<
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F8>, element_wise::Scale, 6>>>&);
#endif
template <typename InDataTypeTuple, template <typename InDataTypeTuple,
typename OutDataTypeTuple, typename OutDataTypeTuple,
typename ElementwiseOperation, typename ElementwiseOperation,
...@@ -184,6 +190,13 @@ struct DeviceOperationInstanceFactory< ...@@ -184,6 +190,13 @@ struct DeviceOperationInstanceFactory<
{ {
add_device_permute_scale_6d_f16_instances(op_ptrs); add_device_permute_scale_6d_f16_instances(op_ptrs);
} }
#endif
#ifdef CK_ENABLE_FP8
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F8>>)
{
add_device_permute_scale_6d_f32_f8_instances(op_ptrs);
}
#endif #endif
} }
return op_ptrs; return op_ptrs;
......
...@@ -10,6 +10,7 @@ namespace tensor_operation { ...@@ -10,6 +10,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// Short alias for ck::f8_t, in the same style as the F16/F32 aliases that follow.
using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -46,7 +47,7 @@ using device_permute_scale_f16_instances = ...@@ -46,7 +47,7 @@ using device_permute_scale_f16_instances =
#if 0 #if 0
// Disabled instances to improve compilation time // Disabled instances to improve compilation time
// They listed here to show other possible combinations of parameters // They listed here to show other possible combinations of parameters
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
...@@ -57,7 +58,7 @@ using device_permute_scale_f16_instances = ...@@ -57,7 +58,7 @@ using device_permute_scale_f16_instances =
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
...@@ -97,7 +98,7 @@ using device_permute_scale_f16_instances = ...@@ -97,7 +98,7 @@ using device_permute_scale_f16_instances =
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>> DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
>; >;
template <index_t NDims, template <index_t NDims,
...@@ -131,7 +132,7 @@ using device_permute_scale_f32_instances = std::tuple< ...@@ -131,7 +132,7 @@ using device_permute_scale_f32_instances = std::tuple<
#if 0 #if 0
// Disabled instances to improve compilation time // Disabled instances to improve compilation time
// They listed here to show other possible combinations of parameters // They listed here to show other possible combinations of parameters
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
...@@ -142,7 +143,7 @@ using device_permute_scale_f32_instances = std::tuple< ...@@ -142,7 +143,7 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
...@@ -168,7 +169,7 @@ using device_permute_scale_f32_instances = std::tuple< ...@@ -168,7 +169,7 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
#endif #endif
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
...@@ -183,6 +184,51 @@ using device_permute_scale_f32_instances = std::tuple< ...@@ -183,6 +184,51 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>> DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
>; >;
#ifdef CK_ENABLE_FP8
// Tuple of DeviceElementwiseImpl configurations with ck::Tuple<F32> as the first data-type
// tuple and ck::Tuple<F8> as the second, parameterised on the number of dimensions and
// the element-wise operation. Only available when CK_ENABLE_FP8 is defined.
// NOTE(review): the five integers after NDims are presumably block size, per-block tile
// sizes and per-thread tile sizes, and the last two ck::Sequence<> arguments presumably
// the input/output scalar-per-vector widths - confirm against DeviceElementwiseImpl.
template <index_t NDims,
typename ElementwiseOp>
using device_permute_scale_f32_f8_instances = std::tuple<
// Group 1: trailing sequences <4>/<4>, per-thread values 4, 4.
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
// Group 2: trailing sequences <8>/<8>, per-thread values 8, 8.
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
// Group 3: same leading parameters as group 1, but trailing sequences <1>/<1>.
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
>;
#endif
// clang-format on // clang-format on
} // namespace instance } // namespace instance
......
...@@ -14,15 +14,24 @@ namespace device { ...@@ -14,15 +14,24 @@ namespace device {
namespace instance { namespace instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex // InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>>&); extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>>&); extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&); extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&); extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>>&); extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>>&); extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&); extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&); extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
// Explicit instantiation declarations with PropagateNan = true and UseIndex = false
// (the last two template arguments, per the column legend at the top of this list).
// Rank 6/5/4 with NumReduceDim equal to the rank or 3, UnaryAbs as InElementwiseOp:
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
// Rank 3/2/1 with every dimension reduced; these use PassThrough (not UnaryAbs) as InElementwiseOp:
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>>&);
// clang-format on // clang-format on
} // namespace instance } // namespace instance
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -146,7 +146,7 @@ check_err(const Range& out, ...@@ -146,7 +146,7 @@ check_err(const Range& out,
bool res{true}; bool res{true};
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min(); double max_err = NumericLimits<ranges::range_value_t<Range>>::Min();
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const double o = type_convert<float>(*std::next(std::begin(out), i)); const double o = type_convert<float>(*std::next(std::begin(out), i));
...@@ -178,7 +178,9 @@ check_err(const Range& out, ...@@ -178,7 +178,9 @@ check_err(const Range& out,
template <typename Range, typename RefRange> template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> && std::is_integral_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>) !std::is_same_v<ranges::range_value_t<Range>, bhalf_t> &&
!std::is_same_v<ranges::range_value_t<Range>, f8_t> &&
!std::is_same_v<ranges::range_value_t<Range>, bf8_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<ranges::range_value_t<Range>, int4_t> || std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif #endif
...@@ -270,7 +272,8 @@ check_err(const Range& out, ...@@ -270,7 +272,8 @@ check_err(const Range& out,
} }
if(!res) if(!res)
{ {
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
<< " number of errors: " << err_count << std::endl;
} }
return res; return res;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -46,6 +46,21 @@ std::vector<std::size_t> get_layout_transpose_gnchw_to_old() ...@@ -46,6 +46,21 @@ std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{ {
return {0, 1, 2, 3}; return {0, 1, 2, 3};
} }
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKW>)
{
return {1, 0, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKHW>)
{
return {1, 0, 2, 3, 4};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCDHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKDHW>)
{
return {1, 0, 2, 3, 4, 5};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> || else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> || ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>) ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>)
...@@ -132,6 +147,18 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa ...@@ -132,6 +147,18 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa
param.input_spatial_lengths_.begin() + param.num_dim_spatial_); param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
// separate from legacy code above // separate from legacy code above
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> || else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> || ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>) ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
...@@ -314,6 +341,19 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP ...@@ -314,6 +341,19 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP
param.output_spatial_lengths_.begin(), param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_); param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
// separate from legacy code above
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.end(),
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> || else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> || ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>) ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>)
......
...@@ -111,6 +111,7 @@ list(APPEND GEMM_INSTANCES ...@@ -111,6 +111,7 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES list(APPEND GEMM_INSTANCES
device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
......
...@@ -11,4 +11,9 @@ list(APPEND GEMM_AB_SCALE_INSTANCES ...@@ -11,4 +11,9 @@ list(APPEND GEMM_AB_SCALE_INSTANCES
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
) )
# Per-source compile options for the four "comp" ab_scale instance files
# (default / kpadding / mnpadding / mnkpadding): each one is built with the
# extra LLVM backend option `-mllvm -greedy-reverse-local-assignment=1`.
# NOTE(review): presumably a register-allocation tuning knob needed by these
# particular kernels -- confirm the motivation with the author.
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES}) add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES})
...@@ -4,14 +4,13 @@ set(GEMM_MULTIPLY_MULTIPLY_INSTANCES) ...@@ -4,14 +4,13 @@ set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)
list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
) )
# Per-source compile options for the two "comp" multiply_multiply instance
# files (default / kpadding): each one is built with the extra LLVM backend
# option `-mllvm -greedy-reverse-local-assignment=1`.
# NOTE(review): presumably a register-allocation tuning knob needed by these
# particular kernels -- confirm the motivation with the author.
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_multiply_multiply_instance ${GEMM_MULTIPLY_MULTIPLY_INSTANCES}) add_instance_library(device_gemm_multiply_multiply_instance ${GEMM_MULTIPLY_MULTIPLY_INSTANCES})
...@@ -46,8 +46,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std ...@@ -46,8 +46,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances<GemmMNPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances<Intrawave,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment