gaoqiong / composable_kernel_ROCM · Commit 72c9f129

Merge branch 'amd-develop' into amd-master

Authored Sep 20, 2024 by Jun Liu
Parents: 241c261f, ded0d83d

Changes: 235 files in total; this page shows 20 changed files with 402 additions and 190 deletions (+402 -190).
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp   +83 -1
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc   +112 -0
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp   +13 -0
library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp   +52 -6
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp   +18 -9
library/include/ck/library/utility/check_err.hpp   +7 -4
library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp   +41 -1
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt   +1 -0
library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt   +5 -0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt   +3 -4
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp   +1 -2
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp   +11 -11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp   +11 -11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp   +0 -32 (deleted)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp   +0 -32 (deleted)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp   +11 -11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp   +11 -11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp   +0 -33 (deleted)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp   +11 -11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp   +11 -11
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp
...
@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
-#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

 namespace ck {
...
@@ -99,6 +99,88 @@ struct DeviceOperationInstanceFactory<
     }
 };

+using CombConvScaleRelu = ck::tensor_operation::element_wise::ScaleScaleRelu;
+
+#ifdef CK_ENABLE_FP8
+void add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
+        3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK, F8, F8, ck::Tuple<>, F32,
+        PassThrough, PassThrough, CombConvScaleRelu, F8, F8>>>& instances);
+#endif
+
+template <ck::index_t NumDimSpatial,
+          typename InLayout,
+          typename WeiLayout,
+          typename DLayouts,
+          typename OutLayout,
+          typename InDataType,
+          typename WeiDataType,
+          typename DDataTypes,
+          typename OutDataType,
+          typename AComputeType,
+          typename BComputeType>
+struct DeviceOperationInstanceFactory<
+    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
+        NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout,
+        InDataType, WeiDataType, DDataTypes, OutDataType,
+        PassThrough, PassThrough, CombConvScaleRelu, AComputeType, BComputeType>>
+{
+    using DeviceOp = DeviceGroupedConvFwdMultipleABD<
+        NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout,
+        InDataType, WeiDataType, DDataTypes, OutDataType,
+        PassThrough, PassThrough, CombConvScaleRelu, AComputeType, BComputeType>;
+
+    static auto GetInstances()
+    {
+        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+
+        if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
+                     is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
+        {
+#ifdef CK_ENABLE_FP8
+            if constexpr(is_same_v<InDataType, f8_t> && is_same_v<WeiDataType, f8_t> &&
+                         is_same_v<OutDataType, F32> && is_same_v<AComputeType, f8_t> &&
+                         is_same_v<BComputeType, f8_t>)
+            {
+                add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
+                    op_ptrs);
+            }
+#endif
+        }
+
+        return op_ptrs;
+    }
+};
+
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
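For context, a minimal sketch (not part of the diff) of how such a factory specialization is typically consumed. The GetInstances / IsSupportedArgument / MakeInvokerPointer pattern is composable_kernel's standard one, but the conv-problem arguments to MakeArgumentPointer are elided here, so treat this as illustrative rather than a drop-in snippet:

    using ConvOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
        3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK, F8, F8, ck::Tuple<>, F32,
        PassThrough, PassThrough, CombConvScaleRelu, F8, F8>;

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<ConvOp>::GetInstances();

    for(auto& op : op_ptrs)
    {
        auto arg = op->MakeArgumentPointer(/* conv problem description elided */);
        if(op->IsSupportedArgument(arg.get()))
        {
            op->MakeInvokerPointer()->Run(arg.get(), StreamConfig{});
            break; // run the first instance that supports the problem
        }
    }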
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
        2, NHWGC, GKYXC, Empty_Tuple, NHWGK, BF16, BF16, Empty_Tuple, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
        2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, Empty_Tuple, F16,
        PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
        2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F32, F32, Empty_Tuple, F32,
        PassThrough, PassThrough, PassThrough>>>& instances);
#endif

// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
        3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, Empty_Tuple, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
        3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, Empty_Tuple, F16,
        PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
        3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, Empty_Tuple, F32,
        PassThrough, PassThrough, PassThrough>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
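Each declaration above is expected to have a matching definition in an instance translation unit. As a hedged sketch of that pattern (the instance-tuple name below is hypothetical, not taken from this commit; the add_device_operation_instances call itself is the same one the .cpp files later in this commit use):

    // Hypothetical definition side of the first declaration above; the alias
    // device_grouped_conv2d_fwd_xdl_merged_groups_bf16_instances is assumed.
    void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
        std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
            2, NHWGC, GKYXC, Empty_Tuple, NHWGK, BF16, BF16, Empty_Tuple, BF16,
            PassThrough, PassThrough, PassThrough>>>& instances)
    {
        // add_device_operation_instances appends one heap-allocated device op
        // per type in the tuple to `instances`.
        add_device_operation_instances(
            instances, device_grouped_conv2d_fwd_xdl_merged_groups_bf16_instances{});
    }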
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp
...
@@ -70,6 +70,12 @@ void add_device_permute_scale_6d_f32_instances(
         DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 6>>>&);
 #endif

+#ifdef CK_ENABLE_FP8
+void add_device_permute_scale_6d_f32_f8_instances(
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F8>, element_wise::Scale, 6>>>&);
+#endif
+
 template <typename InDataTypeTuple,
           typename OutDataTypeTuple,
           typename ElementwiseOperation,
...

@@ -184,6 +190,13 @@ struct DeviceOperationInstanceFactory<
         {
             add_device_permute_scale_6d_f16_instances(op_ptrs);
         }
 #endif
+#ifdef CK_ENABLE_FP8
+        if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
+                     is_same_v<OutDataTypeTuple, ck::Tuple<F8>>)
+        {
+            add_device_permute_scale_6d_f32_f8_instances(op_ptrs);
+        }
+#endif
     }

     return op_ptrs;
...
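A hedged usage sketch for the new F32 to F8 path; DeviceElementwise with these template arguments is the factory key confirmed by the header above, while the surrounding setup (shapes, strides, buffers) is elided:

    // Minimal sketch: enumerate the new F32 -> F8 permute+scale instances.
    using PermuteScaleOp = ck::tensor_operation::device::DeviceElementwise<
        ck::Tuple<F32>, ck::Tuple<F8>, ck::tensor_operation::element_wise::Scale, 6>;

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<PermuteScaleOp>::GetInstances();
    // op_ptrs now holds one entry per instance registered under CK_ENABLE_FP8.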
library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp
...
@@ -10,6 +10,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

+using F8  = ck::f8_t;
 using F16 = ck::half_t;
 using F32 = float;
...

@@ -46,7 +47,7 @@ using device_permute_scale_f16_instances =
 #if 0
 // Disabled instances to improve compilation time
 // They listed here to show other possible combinations of parameters
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
...

@@ -57,7 +58,7 @@ using device_permute_scale_f16_instances =
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
...

@@ -97,7 +98,7 @@ using device_permute_scale_f16_instances =
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
 DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
 >;

 template <index_t NDims,
...

@@ -131,7 +132,7 @@ using device_permute_scale_f32_instances = std::tuple<
 #if 0
 // Disabled instances to improve compilation time
 // They listed here to show other possible combinations of parameters
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
...

@@ -142,7 +143,7 @@ using device_permute_scale_f32_instances = std::tuple<
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
...

@@ -168,7 +169,7 @@ using device_permute_scale_f32_instances = std::tuple<
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
 #endif
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
...

@@ -183,6 +184,51 @@ using device_permute_scale_f32_instances = std::tuple<
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
 DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
 >;

+#ifdef CK_ENABLE_FP8
+template <index_t NDims, typename ElementwiseOp>
+using device_permute_scale_f32_f8_instances = std::tuple<
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
+    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
+    >;
+#endif
 // clang-format on
 } // namespace instance
...
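A hedged reading of one new instance's template arguments; the parameter names follow DeviceElementwiseImpl's usual declaration order in composable_kernel and are an assumption here, so verify against the implementation header before relying on them:

    // DeviceElementwiseImpl<ck::Tuple<F32>,      InDataTypeTuple
    //                       ck::Tuple<F8>,       OutDataTypeTuple
    //                       ElementwiseOp,       elementwise functor (Scale)
    //                       NDims,               tensor rank
    //                       256,                 BlockSize (threads per block)
    //                       64, 64,              M0PerBlock, M1PerBlock tile
    //                       4, 4,                M0PerThread, M1PerThread
    //                       ck::Sequence<1, 0>,  thread cluster arrange order
    //                       ck::Sequence<4>,     input scalars per vector load
    //                       ck::Sequence<4>>     output scalars per vector store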
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp
...
@@ -14,15 +14,24 @@ namespace device {
 namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
 extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
 extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
 extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
 extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
 extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
 extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
 extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
 extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
+extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
+extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
+extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
+extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
+extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
+extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
+extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>>&);
+extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>>&);
+extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>>&);
 // clang-format on
 } // namespace instance
...
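These extern template declarations tell each including translation unit that the explicit instantiations already exist in a prebuilt object file, so the compiler skips re-instantiating them and compile time drops. A minimal self-contained sketch of the mechanism, with illustrative names that are not from CK:

    #include <vector>

    // Illustrative template; not a CK API.
    template <typename T>
    void register_instances(std::vector<T>& v) { v.push_back(T{}); }

    // Declares that the explicit instantiation for int lives in another
    // translation unit, so this TU will not emit it.
    extern template void register_instances<int>(std::vector<int>&);

    // Exactly one .cpp in the library provides the definition:
    // template void register_instances<int>(std::vector<int>&);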
library/include/ck/library/utility/check_err.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -146,7 +146,7 @@ check_err(const Range& out,
     bool res{true};
     int err_count  = 0;
     double err     = 0;
-    double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
+    double max_err = NumericLimits<ranges::range_value_t<Range>>::Min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
         const double o = type_convert<float>(*std::next(std::begin(out), i));
...
@@ -178,7 +178,9 @@ check_err(const Range& out,
 template <typename Range, typename RefRange>
 std::enable_if_t<
     (std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
      std::is_integral_v<ranges::range_value_t<Range>> &&
-     !std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
+     !std::is_same_v<ranges::range_value_t<Range>, bhalf_t> &&
+     !std::is_same_v<ranges::range_value_t<Range>, f8_t> &&
+     !std::is_same_v<ranges::range_value_t<Range>, bf8_t>)
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
         || std::is_same_v<ranges::range_value_t<Range>, int4_t>
 #endif
...
@@ -270,7 +272,8 @@ check_err(const Range& out,
     }
     if(!res)
     {
-        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
+                  << " number of errors: " << err_count << std::endl;
     }
     return res;
 }
...
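For context, a hedged sketch of typical check_err usage; the optional message argument and default tolerances are assumptions to verify against check_err.hpp:

    // Minimal sketch: compare a device result against a CPU reference.
    std::vector<float> device_result  = /* copied back from GPU */;
    std::vector<float> host_reference = /* CPU reference computation */;
    bool pass = ck::utils::check_err(device_result, host_reference,
                                     "Error: incorrect results!");
    // On failure, the updated code reports both figures:
    //   "max err: <value> number of errors: <count>"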
library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -46,6 +46,21 @@ std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
     {
         return {0, 1, 2, 3};
     }
+    else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCW> ||
+                      ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKW>)
+    {
+        return {1, 0, 2, 3};
+    }
+    else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCHW> ||
+                      ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKHW>)
+    {
+        return {1, 0, 2, 3, 4};
+    }
+    else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCDHW> ||
+                      ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKDHW>)
+    {
+        return {1, 0, 2, 3, 4, 5};
+    }
     else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> ||
                       ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> ||
                       ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>)
...
@@ -132,6 +147,18 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa
                                 param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
+    // separate from legacy code above
+    else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCW> ||
+                      ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCHW> ||
+                      ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCDHW>)
+    {
+        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
+                                                    static_cast<std::size_t>(param.G_),
+                                                    static_cast<std::size_t>(param.C_)};
+        physical_lengths.insert(physical_lengths.end(),
+                                param.input_spatial_lengths_.begin(),
+                                param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
+    }
     else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
                       ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
                       ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
...
@@ -314,6 +341,19 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP
                                 param.output_spatial_lengths_.begin(),
                                 param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
+    // separate from legacy code above
+    else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKW> ||
+                      ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKHW> ||
+                      ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKDHW>)
+    {
+        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
+                                                    static_cast<std::size_t>(param.G_),
+                                                    static_cast<std::size_t>(param.K_)};
+        physical_lengths.insert(physical_lengths.end(),
+                                param.output_spatial_lengths_.begin(),
+                                param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
+    }
     else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> ||
                       ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> ||
                       ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>)
...
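An illustrative example of what the new branch produces; the call form below assumes the helper's namespace and template parameter, so verify before copying:

    // With the new NGCHW branch, a 2-D problem with N=2, G=4, C=8 and a 16x16
    // input packs its physical lengths as {N, G, C, H, W} = {2, 4, 8, 16, 16};
    // the group dimension sits between batch and channels, whereas the legacy
    // GNCHW branch leads with G.
    auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
        ck::tensor_layout::convolution::NGCHW>(param); // `param` describes the problem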
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...
@@ -111,6 +111,7 @@ list(APPEND GEMM_INSTANCES
   device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
   device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp
 )
 list(APPEND GEMM_INSTANCES
   device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
   device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
...
library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt
...
@@ -11,4 +11,9 @@ list(APPEND GEMM_AB_SCALE_INSTANCES
   device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
 )
+
+set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+
 add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES})
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
...
@@ -4,14 +4,13 @@ set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)
 list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
   device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
   device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
-  device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
-  device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
   device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
   device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
-  device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
   device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
   device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
-  device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
 )
+
+set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
 add_instance_library(device_gemm_multiply_multiply_instance ${GEMM_MULTIPLY_MULTIPLY_INSTANCES})
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
...
@@ -46,8 +46,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std
 DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
 DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
 DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
 DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
 DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
 DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
 DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
 DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
 DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
...
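A hedged reading of the leading tuning arguments in the first instance above; the parameter names follow the usual CK DeviceGemmMultiD_Xdl_CShuffle_V3 order and are an assumption to confirm against the kernel header:

    // ..., GemmSpec,
    // 256,            BlockSize
    // 256, 256, 64,   MPerBlock, NPerBlock, KPerBlock
    // 16, 16,         AK1, BK1 (K-dim vector widths for the A and B loads)
    // 32, 32,         MPerXDL, NPerXDL (MFMA tile)
    // 4, 4,           MXdlPerWave, NXdlPerWave
    // ...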
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
...
@@ -9,17 +9,17 @@ namespace device {
 namespace instance {

 void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
-    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
 {
     add_device_operation_instances(instances,
...
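The base interface these instances are registered under changes from DeviceGemmMultipleD to DeviceGemmMultipleDSplitK, so callers now enumerate them via the split-K factory key. A sketch of the caller-visible consequence; the extra split-K batch parameter is an assumption about DeviceGemmMultipleDSplitK and should be verified against its header:

    using GemmOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
        Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
        PassThrough, PassThrough, MultiplyMultiply>;

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<GemmOp>::GetInstances();
    // Each op is now created and queried through the split-K interface,
    // e.g. with a k_batch argument when building the kernel argument.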
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
...
@@ -9,17 +9,17 @@ namespace device {
 namespace instance {

 void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
-    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
 {
     add_device_operation_instances(instances,
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
(deleted, mode 100644 → 0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances<GemmMNKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
(deleted, mode 100644 → 0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances<GemmMNPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
...
@@ -9,17 +9,17 @@ namespace device {
 namespace instance {

 void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
-    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
 {
     add_device_operation_instances(instances,
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
...
@@ -9,17 +9,17 @@ namespace device {
 namespace instance {

 void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
-    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
 {
     add_device_operation_instances(instances,
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
(deleted, mode 100644 → 0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances<Intrawave, GemmMNKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
...
@@ -9,17 +9,17 @@ namespace device {
 namespace instance {

 void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
-    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
 {
     add_device_operation_instances(instances,
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
...
@@ -9,17 +9,17 @@ namespace device {
 namespace instance {

 void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
-    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances)
 {
     add_device_operation_instances(instances,
...