gaoqiong / composable_kernel_ROCM · Commits · 3d61f89a

Unverified commit 3d61f89a, authored Aug 21, 2024 by Illia Silin, committed by GitHub on Aug 21, 2024.

Merge pull request #134 from ROCm/merge_from_public

Merge from public

Parents: c160c6cf, 4558a3f8
Changes: 333

Showing 20 changed files with 1829 additions and 157 deletions (+1829 −157).
library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp  (+226 −0)
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp  (+166 −0)
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp  (+111 −106)
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp  (+457 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp  (+4 −4)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_binary_outelementop_instance.hpp  (+78 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp  (+93 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp  (+74 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp  (+13 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp  (+24 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp  (+105 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp  (+189 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc  (+112 −0)
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp  (+13 −0)
library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp  (+46 −0)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp  (+18 −9)
library/include/ck/library/utility/check_err.hpp  (+4 −2)
library/include/ck/library/utility/convolution_parameter.hpp  (+30 −18)
library/include/ck/library/utility/host_tensor.hpp  (+18 −3)
library/src/tensor_operation_instance/gpu/CMakeLists.txt  (+48 −15)
library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

template <typename A0DataType,
          typename A1DataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<ALayout, BLayout, Tuple<>, CLayout, A0DataType, A1DataType, B0DataType, B1DataType, Tuple<>, CDataType, 128, 128, 128, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGemmMultipleD_ABScale<ALayout, BLayout, Tuple<>, CLayout, A0DataType, A1DataType, B0DataType, B1DataType, Tuple<>, CDataType, 128, 128, 128, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(op_ptrs);
            }
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
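For context, a minimal sketch of how an instance factory like the one added above is typically consumed. This example is not part of the commit; it assumes the usual CK aliases (Row = ck::tensor_layout::gemm::RowMajor, Col = ck::tensor_layout::gemm::ColumnMajor, F32 = float) and that the returned operations expose GetTypeString(), so treat it as illustrative only.

// Illustrative usage sketch (not from the commit): enumerate the f8 x f8 -> bf16
// A/B-scale GEMM instances registered by the new header and print their names.
#include "ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp"
#include <iostream>

int main()
{
    using namespace ck::tensor_operation::device;
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Template arguments follow the order used by the factory specialization above:
    // layouts, Ds layout, A0/A1(scale), B0/B1(scale), Ds types, C type, scale block sizes, element-wise ops.
    using DeviceOp = DeviceGemmMultipleD_ABScale<Row, Col, ck::Tuple<>, Row,
                                                 ck::f8_t, float, ck::f8_t, float,
                                                 ck::Tuple<>, ck::bhalf_t,
                                                 128, 128, 128,
                                                 PassThrough, PassThrough, PassThrough>;

    auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n'; // one line per pre-built kernel instance
}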
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
#endif

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout, BLayout, Tuple<Row, Col>, CLayout, ADataType, BDataType, Tuple<F32, F32>, CDataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::MultiplyMultiply>>
{
    using DeviceOp = DeviceGemmMultipleDSplitK<ALayout, BLayout, Tuple<Row, Col>, CLayout, ADataType, BDataType, Tuple<F32, F32>, CDataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::MultiplyMultiply>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
            }
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
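These instances carry two extra F32 D tensors (the Tuple<F32, F32> argument) and combine them with the accumulator through the MultiplyMultiply epilogue, which is the usual pattern for dequantizing f8 GEMM output with per-row and per-column scales. The sketch below is illustrative only and mirrors the assumed semantics e = c * d0 * d1; it is a standalone stand-in, not the CK element-wise operator itself.

// Standalone sketch of the assumed MultiplyMultiply semantics (hypothetical type,
// not the ck::tensor_operation::element_wise implementation).
#include <cstdio>

struct MultiplyMultiplySketch
{
    // e: output element, c: GEMM accumulator, d0/d1: elements of the two extra D tensors
    template <typename E, typename C, typename D0, typename D1>
    void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
    {
        e = static_cast<E>(c * d0 * d1);
    }
};

int main()
{
    float acc = 42.0f, row_scale = 0.5f, col_scale = 0.25f;
    float out = 0.0f;
    MultiplyMultiplySketch{}(out, acc, row_scale, col_scale);
    std::printf("%f\n", out); // prints 5.250000
}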
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp (modified)
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...

@@ -77,16 +77,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
    DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...

@@ -97,11 +87,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance
    DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...

@@ -111,13 +96,8 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

-#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
+#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
...

@@ -177,16 +157,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
    DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
...

@@ -196,12 +166,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
...

@@ -212,10 +176,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
    DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...

@@ -275,16 +235,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
    DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...

@@ -295,11 +245,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
    DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...

@@ -309,13 +254,8 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

-#ifdef CK_ENABLE_FP16
+#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
...

@@ -376,44 +316,100 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instanc
    DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

-void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

-void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

-void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

-void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

-void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
-    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

-void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
-    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

-void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
-    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

-void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
-    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
...

@@ -481,28 +477,20 @@ struct DeviceOperationInstanceFactory<
            {
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif
-#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
+#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
                     is_same_v<CDataType, half_t>)
        {
...

@@ -511,21 +499,14 @@ struct DeviceOperationInstanceFactory<
            {
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                              is_same_v<CLayout, Row>)
...

@@ -557,21 +538,14 @@ struct DeviceOperationInstanceFactory<
            {
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                              is_same_v<CLayout, Row>)
...

@@ -596,7 +570,7 @@ struct DeviceOperationInstanceFactory<
        }
    }
#endif
-#ifdef CK_ENABLE_FP16
+#ifdef CK_ENABLE_BF16
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
...

@@ -633,23 +607,54 @@ struct DeviceOperationInstanceFactory<
                    op_ptrs);
                add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
            }
        }
#endif
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
            }
        }
...
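Two of the changes in this file tighten preprocessor guards around mixed-precision declarations from || to &&: an f16 x f8 (or f8 x bf16) instance can only be declared when both element types are compiled in, whereas the old || guard exposed the declarations even in builds where one of the two types was disabled. The sketch below is illustrative only and uses hypothetical macros and function names, not the CK ones, to show the effect of the stricter guard.

// Illustrative sketch of the ||-vs-&& guard behavior (hypothetical macros/types).
#include <iostream>

#define ENABLE_FP16 1
// #define ENABLE_FP8 1   // suppose fp8 support is switched off in this build

#if(defined(ENABLE_FP16) && defined(ENABLE_FP8))
// With &&, this mixed-type registration only exists when both types are available.
void add_f16_f8_instances() { std::cout << "f16 x f8 instances registered\n"; }
#endif

int main()
{
#if(defined(ENABLE_FP16) && defined(ENABLE_FP8))
    add_f16_f8_instances();
#else
    std::cout << "f16 x f8 instances skipped: fp8 not enabled\n";
#endif
}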
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using DsLayout   = ck::Tuple<>;
using DsDataType = ck::Tuple<>;

#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmV2R1<ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGemmV2R1<ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif
#ifdef CK_ENABLE_BF16
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
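As with the earlier sketch, consumers would query this factory by instantiating DeviceGemmV2R1 with concrete types. The snippet below is illustrative only (not from the commit) and assumes the usual CK layout aliases; it counts the bf16 x int8 -> bf16 row/row/row instances registered above.

// Illustrative usage sketch (not from the commit).
#include "ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp"
#include <cstdint>
#include <iostream>

int main()
{
    using namespace ck::tensor_operation::device;
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using DeviceOp = DeviceGemmV2R1<Row, Row, ck::Tuple<>, Row,
                                    ck::bhalf_t, int8_t, ck::Tuple<>, ck::bhalf_t,
                                    PassThrough, PassThrough, PassThrough>;

    auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    std::cout << "found " << op_ptrs.size() << " bf16 x i8 universal-reduce instances\n";
}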
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp (modified)
...

@@ -40,10 +40,10 @@ template <ck::index_t NDimSpatial,
          BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple<
    // clang-format off
-    //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumBatch|
+    //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
    //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
    //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
    //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>,
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_binary_outelementop_instance.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F32 = float;
using F8  = ck::f8_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;

static constexpr auto ConvFwdOddC =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;

static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec,
          typename OutElementOp>
using device_grouped_conv_fwd_xdl_binary_outelementop_f8_instances = std::tuple<
    // clang-format off
    //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute|
    //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB|
    //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
    //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef CK_ENABLE_FP8
        // generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
        // instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<F32>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>
#endif
    // clang-format on
    >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Empty_Tuple = ck::Tuple<>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_large_tensor_bf16_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // generic instance
        DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
        DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_large_tensor_f16_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // generic instance
        DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
        DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // generic instance
        DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
        DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
    // clang-format on
    >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
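For orientation (this note and sketch are not part of the commit): each of the three large-tensor tuples above carries exactly two tuning variants per data type, a small generic 64x64x64 tile and a larger 256x256x128 tile. A minimal sketch of how the f16 tuple could be instantiated for a 2-D NHWGC/GKYXC/NHWGK problem, assuming the layout tags from ck::tensor_layout::convolution and the aliases defined in this header (the alias name below is hypothetical):

// Illustrative only; requires <tuple> for std::tuple_size_v.
using LargeTensorF16Instances =
    device_grouped_conv_fwd_xdl_large_tensor_f16_instances<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, ConvFwdDefault>;
static_assert(std::tuple_size_v<LargeTensorF16Instances> == 2, "one generic and one large-tile kernel per data type");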
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
...
...
@@ -147,6 +147,80 @@ using device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances = std::tuple<
    // clang-format on
    >;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec,
          typename OutElementOp>
using device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)
        // generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>,
        // instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>
#endif
    // clang-format on
    >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec,
          typename OutElementOp>
using device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef CK_ENABLE_FP8
        // generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
        // instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>
#endif
    // clang-format on
    >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
...
...
@@ -17,6 +17,7 @@
#endif
#ifdef CK_USE_XDL
#include "grouped_convolution_forward_xdl.inc"
#include "grouped_convolution_forward_xdl_large_tensor.inc"
#include "grouped_convolution_forward_comp_xdl.inc"
#include "grouped_convolution_forward_mem_inter_xdl.inc"
#include "grouped_convolution_forward_mem_intra_xdl.inc"
...
...
@@ -199,6 +200,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                         is_same_v<BComputeType, float>)
            {
                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(op_ptrs);
...
...
@@ -212,6 +215,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                         is_same_v<BComputeType, half_t>)
            {
                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(op_ptrs);
...
...
@@ -227,6 +232,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                         is_same_v<BComputeType, ck::bhalf_t>)
            {
                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs);
                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(op_ptrs);
...
...
@@ -284,6 +291,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                         is_same_v<BComputeType, float>)
            {
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(op_ptrs);
...
...
@@ -338,6 +347,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                         is_same_v<BComputeType, half_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(op_ptrs);
...
...
@@ -353,6 +364,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                         is_same_v<BComputeType, ck::bhalf_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs);
                add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(op_ptrs);
...
...
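As a rough usage sketch (not part of the diff): with the registrations above, a plain forward-convolution factory query now also returns the large-tensor kernels. Assuming the 2-D f16 NHWGC problem type declared in grouped_convolution_forward_xdl_large_tensor.inc, with the compute types left at their defaults:

// Illustrative only; op_ptrs then also contains the *_large_tensor_* instances added here.
using ConvFwd2dF16 = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2, NHWGC, GKYXC, ck::Tuple<>, NHWGK, F16, F16, ck::Tuple<>, F16,
    PassThrough, PassThrough, PassThrough>;
auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<ConvFwd2dF16>::GetInstances();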
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp
...
...
@@ -70,6 +70,22 @@ void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_ins
                                                                ConvScale, F8, BF8>>>& instances);

void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK, BF8, F8, ck::Tuple<>, F8, PassThrough, PassThrough, ConvScale, BF8, F8>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
...
...
@@ -147,6 +163,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(op_ptrs);
            }
            if constexpr(is_same_v<InDataType, bf8_t> && is_same_v<WeiDataType, f8_t> &&
                         is_same_v<OutDataType, f8_t> && is_same_v<AComputeType, bf8_t> &&
                         is_same_v<BComputeType, f8_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(op_ptrs);
            }
#endif
        }
        return op_ptrs;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using PassThrough  = ck::tensor_operation::element_wise::PassThrough;
using ConvScaleAdd = ck::tensor_operation::element_wise::ConvScaleAdd;

#ifdef CK_ENABLE_FP8
void add_device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK>, NDHWGK, F8, F8, ck::Tuple<F32>, F8, PassThrough, PassThrough, ConvScaleAdd, F8, F8>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DLayouts,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DDataTypes,
          typename OutDataType,
          typename AComputeType,
          typename BComputeType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout, InDataType, WeiDataType, DDataTypes, OutDataType, PassThrough, PassThrough, ConvScaleAdd, AComputeType, BComputeType>>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout, InDataType, WeiDataType, DDataTypes, OutDataType, PassThrough, PassThrough, ConvScaleAdd, AComputeType, BComputeType>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
                     is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
        {
#ifdef CK_ENABLE_FP8
            if constexpr(is_same_v<InDataType, f8_t> && is_same_v<WeiDataType, f8_t> &&
                         is_same_v<OutDataType, f8_t> && is_same_v<AComputeType, f8_t> &&
                         is_same_v<BComputeType, f8_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs);
            }
#endif
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
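A minimal sketch (illustrative, not from the commit) of how this specialization is reached, assuming the FP8 NDHWGC 3-D problem with one F32 D tensor exactly as declared above:

// Hypothetical alias names; the factory key mirrors the declaration in this header.
using ConvScaleAddOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK>, NDHWGK, F8, F8, ck::Tuple<F32>, F8,
    PassThrough, PassThrough, ConvScaleAdd, F8, F8>;
auto scale_add_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<ConvScaleAddOp>::GetInstances();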
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using PassThrough   = ck::tensor_operation::element_wise::PassThrough;
using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu;

#ifdef CK_ENABLE_FP8
void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK, F8, F8, ck::Tuple<>, F8, PassThrough, PassThrough, ConvScaleRelu, F8, F8>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DLayouts,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DDataTypes,
          typename OutDataType,
          typename AComputeType,
          typename BComputeType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout, InDataType, WeiDataType, DDataTypes, OutDataType, PassThrough, PassThrough, ConvScaleRelu, AComputeType, BComputeType>>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout, InDataType, WeiDataType, DDataTypes, OutDataType, PassThrough, PassThrough, ConvScaleRelu, AComputeType, BComputeType>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
                     is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
        {
#ifdef CK_ENABLE_FP8
            if constexpr(is_same_v<InDataType, f8_t> && is_same_v<WeiDataType, f8_t> &&
                         is_same_v<OutDataType, f8_t> && is_same_v<AComputeType, f8_t> &&
                         is_same_v<BComputeType, f8_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs);
            }
#endif
        }

        return op_ptrs;
    }
};

namespace ew = ck::tensor_operation::element_wise;
using CombConvScaleRelu = ew::UnaryCombinedOp<ew::Scale, ew::Scale, ew::Relu>;

#ifdef CK_ENABLE_FP8
void add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK, F8, F8, ck::Tuple<>, F32, PassThrough, PassThrough, CombConvScaleRelu, F8, F8>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DLayouts,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DDataTypes,
          typename OutDataType,
          typename AComputeType,
          typename BComputeType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout, InDataType, WeiDataType, DDataTypes, OutDataType, PassThrough, PassThrough, CombConvScaleRelu, AComputeType, BComputeType>>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout, InDataType, WeiDataType, DDataTypes, OutDataType, PassThrough, PassThrough, CombConvScaleRelu, AComputeType, BComputeType>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
                     is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
        {
#ifdef CK_ENABLE_FP8
            if constexpr(is_same_v<InDataType, f8_t> && is_same_v<WeiDataType, f8_t> &&
                         is_same_v<OutDataType, F32> && is_same_v<AComputeType, f8_t> &&
                         is_same_v<BComputeType, f8_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(op_ptrs);
            }
#endif
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
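This header covers two output element-wise flavors: ConvScaleRelu with f8 output, and the combined Scale/Scale/Relu operator (CombConvScaleRelu) with F32 output. A hedged sketch (not part of the commit) of querying the combined variant through the factory specialization defined above:

// Illustrative only; assumes FP8 inputs and the F32 output used by the combined variant.
using CombOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK, F8, F8, ck::Tuple<>, F32,
    PassThrough, PassThrough, CombConvScaleRelu, F8, F8>;
auto comb_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<CombOp>::GetInstances();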
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp
...
...
@@ -70,6 +70,12 @@ void add_device_permute_scale_6d_f32_instances(
        DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 6>>>&);
#endif
#ifdef CK_ENABLE_FP8
void add_device_permute_scale_6d_f32_f8_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F8>, element_wise::Scale, 6>>>&);
#endif

template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
...
...
@@ -184,6 +190,13 @@ struct DeviceOperationInstanceFactory<
        {
            add_device_permute_scale_6d_f16_instances(op_ptrs);
        }
#endif
#ifdef CK_ENABLE_FP8
        if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
                     is_same_v<OutDataTypeTuple, ck::Tuple<F8>>)
        {
            add_device_permute_scale_6d_f32_f8_instances(op_ptrs);
        }
#endif
    }
    return op_ptrs;
...
...
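A hedged usage sketch (not part of the diff), assuming the factory is keyed on the same DeviceElementwise type used in the declarations above and that 6-D F32 input / F8 output tuples are requested:

// Illustrative only.
using PermuteScaleF32ToF8 = ck::tensor_operation::device::DeviceElementwise<
    ck::Tuple<F32>, ck::Tuple<F8>, ck::tensor_operation::element_wise::Scale, 6>;
auto permute_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<PermuteScaleF32ToF8>::GetInstances();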
library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp
...
...
@@ -10,6 +10,7 @@ namespace tensor_operation {
namespace device {
namespace instance {

using F8  = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
...
...
@@ -183,6 +184,51 @@ using device_permute_scale_f32_instances = std::tuple<
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
    >;

#ifdef CK_ENABLE_FP8
template <index_t NDims, typename ElementwiseOp>
using device_permute_scale_f32_f8_instances = std::tuple<
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
    >;
#endif
// clang-format on
} // namespace instance
...
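For orientation (illustrative, not part of the commit): each F32 -> F8 entry above varies only the block size, the per-block M/N tile, and the scalar-per-vector width. The tuple is consumed by instantiating it for a concrete rank and element-wise op, for example:

// Hypothetical alias only.
using PermuteScaleF32F8Instances =
    device_permute_scale_f32_f8_instances<6, ck::tensor_operation::element_wise::Scale>;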
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp
...
...
@@ -14,15 +14,24 @@ namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(
    std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(
    std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(
    std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(
    std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise<F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>(
    std::vector<DeviceReducePtr<F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>>&);
// clang-format on
} // namespace instance
...
...
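The declarations above only register the instance factory; a minimal usage sketch (not part of the commit, names such as the instance namespace and GetTypeString() are assumptions from CK's usual BaseOperator interface) of how the F32 rank-4, 3-dim-reduce AMAX instances might be enumerated:

// Hedged sketch: assumes the F32/ReduceAMax/UnaryAbs/PassThrough aliases and DeviceReducePtr
// visible in this header, and GetTypeString() inherited from CK's BaseOperator.
#include <iostream>
#include <vector>
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp"

using namespace ck::tensor_operation::device::instance; // namespace of the declarations above

int main()
{
    std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>> ops;
    add_device_reduce_instance_blockwise<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(ops);
    for(const auto& op : ops)
        std::cout << op->GetTypeString() << '\n'; // enumerate the registered blockwise AMAX kernels
    return 0;
}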
library/include/ck/library/utility/check_err.hpp
View file @ 3d61f89a
...
@@ -146,7 +146,7 @@ check_err(const Range& out,
bool res{true};
int err_count  = 0;
double err     = 0;
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
double max_err = NumericLimits<ranges::range_value_t<Range>>::Min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
    const double o = type_convert<float>(*std::next(std::begin(out), i));
...
...
@@ -178,7 +178,9 @@ check_err(const Range& out,
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_integral_v<ranges::range_value_t<Range>> &&
                  !std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
                  !std::is_same_v<ranges::range_value_t<Range>, bhalf_t> &&
                  !std::is_same_v<ranges::range_value_t<Range>, f8_t> &&
                  !std::is_same_v<ranges::range_value_t<Range>, bf8_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                 || std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif
...
...
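For reference, a minimal sketch (not part of the diff) of the integral overload this enable_if selects; the ck::utils namespace and the two-range call form are assumptions based on how check_err is usually invoked:

// Hedged sketch: integral ranges (excluding bhalf_t/f8_t/bf8_t) take the exact-comparison path.
#include <cstdint>
#include <vector>
#include "ck/library/utility/check_err.hpp"

int main()
{
    std::vector<std::int32_t> out = {1, 2, 3, 4};
    std::vector<std::int32_t> ref = {1, 2, 3, 5};
    bool ok = ck::utils::check_err(out, ref); // expected to report the mismatch and return false
    return ok ? 0 : 1;
}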
library/include/ck/library/utility/convolution_parameter.hpp
View file @ 3d61f89a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -31,23 +31,35 @@ struct ConvParam
              const std::vector<ck::index_t>& left_pads,
              const std::vector<ck::index_t>& right_pads);

    ck::index_t num_dim_spatial_;
    ck::index_t G_;
    ck::index_t N_;
    ck::index_t K_;
    ck::index_t C_;
    std::vector<ck::index_t> filter_spatial_lengths_;
    std::vector<ck::index_t> input_spatial_lengths_;
    std::vector<ck::index_t> output_spatial_lengths_;
    std::vector<ck::index_t> conv_filter_strides_;
    std::vector<ck::index_t> conv_filter_dilations_;
    std::vector<ck::index_t> input_left_pads_;
    std::vector<ck::index_t> input_right_pads_;

    std::vector<ck::index_t> GetOutputSpatialLengths() const;

    ConvParam(ck::long_index_t n_dim,
              ck::long_index_t group_count,
              ck::long_index_t n_batch,
              ck::long_index_t n_out_channels,
              ck::long_index_t n_in_channels,
              const std::vector<ck::long_index_t>& filters_len,
              const std::vector<ck::long_index_t>& input_len,
              const std::vector<ck::long_index_t>& strides,
              const std::vector<ck::long_index_t>& dilations,
              const std::vector<ck::long_index_t>& left_pads,
              const std::vector<ck::long_index_t>& right_pads);

    ck::long_index_t num_dim_spatial_;
    ck::long_index_t G_;
    ck::long_index_t N_;
    ck::long_index_t K_;
    ck::long_index_t C_;
    std::vector<ck::long_index_t> filter_spatial_lengths_;
    std::vector<ck::long_index_t> input_spatial_lengths_;
    std::vector<ck::long_index_t> output_spatial_lengths_;
    std::vector<ck::long_index_t> conv_filter_strides_;
    std::vector<ck::long_index_t> conv_filter_dilations_;
    std::vector<ck::long_index_t> input_left_pads_;
    std::vector<ck::long_index_t> input_right_pads_;

    std::vector<ck::long_index_t> GetOutputSpatialLengths() const;

    std::size_t GetFlops() const;
...
...
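A minimal construction sketch (not part of the diff) exercising the long_index_t overload shown above; the ck::utils::conv namespace and the 2-D problem sizes are assumptions for illustration:

// Hedged sketch: 2-D grouped convolution, G=1, N=4, K=64, C=32, 3x3 filter, 28x28 input,
// stride/dilation 1 and padding 1 on each side.
#include <vector>
#include "ck/library/utility/convolution_parameter.hpp"

int main()
{
    ck::utils::conv::ConvParam conv_param(
        2, 1, 4, 64, 32,
        std::vector<ck::long_index_t>{3, 3},    // filter spatial lengths
        std::vector<ck::long_index_t>{28, 28},  // input spatial lengths
        std::vector<ck::long_index_t>{1, 1},    // strides
        std::vector<ck::long_index_t>{1, 1},    // dilations
        std::vector<ck::long_index_t>{1, 1},    // left pads
        std::vector<ck::long_index_t>{1, 1});   // right pads
    const auto out_lens = conv_param.GetOutputSpatialLengths(); // {28, 28} with this padding
    return out_lens.empty() ? 1 : 0;
}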
library/include/ck/library/utility/host_tensor.hpp
View file @ 3d61f89a
...
@@ -96,9 +96,16 @@ struct HostTensorDescriptor
        this->CalculateStrides();
    }

    HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens)
        : mLens(lens.begin(), lens.end())
    {
        this->CalculateStrides();
    }

    template <typename Lengths,
              typename = std::enable_if_t<
                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t>>>
                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> ||
                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t>>>
    HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
    {
        this->CalculateStrides();
...
@@ -114,11 +121,19 @@ struct HostTensorDescriptor
    {
    }

    HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens,
                         const std::initializer_list<ck::long_index_t>& strides)
        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
    {
    }

    template <typename Lengths,
              typename Strides,
              typename = std::enable_if_t<
                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
                  std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>>>
                  (std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
                   std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>) ||
                  (std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t> &&
                   std::is_convertible_v<ck::ranges::range_value_t<Strides>, ck::long_index_t>)>>
    HostTensorDescriptor(const Lengths& lens, const Strides& strides)
        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
    {
...
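A minimal sketch (not part of the diff) of descriptors built from long_index_t ranges, which the widened enable_if constraints above are meant to accept; the tensor sizes and GetNumOfDimension() call are illustrative assumptions:

// Hedged sketch: packed and explicitly strided descriptors from long_index_t vectors.
#include <vector>
#include "ck/library/utility/host_tensor.hpp"

int main()
{
    std::vector<ck::long_index_t> lens    = {4, 8, 16};
    std::vector<ck::long_index_t> strides = {128, 16, 1};
    HostTensorDescriptor packed(lens);           // strides filled in by CalculateStrides()
    HostTensorDescriptor strided(lens, strides); // explicit strides
    return packed.GetNumOfDimension() == 3 ? 0 : 1;
}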
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @ 3d61f89a
...
@@ -64,6 +64,13 @@ function(add_instance_library INSTANCE_NAME)
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    # Do not build mha instances if gfx94 targets are not on the target list
    foreach(source IN LISTS ARGN)
        if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha")
            message("removing mha instance ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    #only continue if there are some source files left on the list
    if(ARGN)
        set(INST_OBJ)
...
...
@@ -74,9 +81,11 @@ function(add_instance_library INSTANCE_NAME)
            set(INST_TARGETS ${GPU_TARGETS})
        endif()
        if(source MATCHES "_xdl")
            list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
            list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
        elseif(ARGN MATCHES "_wmma")
            list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
        elseif(ARGN MATCHES "mha")
            list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
        endif()
        set(offload_targets)
        foreach(target IN LISTS INST_TARGETS)
...
...
@@ -86,7 +95,29 @@ function(add_instance_library INSTANCE_NAME)
        list(APPEND INST_OBJ ${source})
    endforeach()
    add_library(${INSTANCE_NAME} OBJECT ${INST_OBJ})
    # Allow comparing floating points directly in order to check sentinel values
    if(${INSTANCE_NAME} STREQUAL "device_mha_instance")
        if(NOT DEFINED FMHA_FWD_FAST_EXP2)
            set(FMHA_FWD_FAST_EXP2 true)
        endif()
        if(FMHA_FWD_FAST_EXP2)
            list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
        else()
            list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
        endif()
        list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
        target_compile_options(device_mha_instance PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
    endif()
    target_compile_features(${INSTANCE_NAME} PUBLIC)
    # flags to compress the library
    if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
        message("Adding --offload-compress flag for ${INSTANCE_NAME}")
        target_compile_options(${INSTANCE_NAME} PRIVATE --offload-compress)
    endif()
    set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
    clang_tidy_check(${INSTANCE_NAME})
    set(result 0)
...
...
@@ -286,20 +317,22 @@ if(CK_DEVICE_CONV_INSTANCES)
    )
endif()
if(CK_DEVICE_MHA_INSTANCES)
    add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
    add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
    target_compile_features(device_mha_operations PUBLIC)
    set_target_properties(device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_include_directories(device_mha_operations PUBLIC
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/mha>)
    rocm_install(TARGETS device_mha_operations
        EXPORT device_mha_operationsTargets)
    rocm_install(EXPORT device_mha_operationsTargets
        FILE composable_kerneldevice_mha_operationsTargets.cmake
        NAMESPACE composable_kernel::
        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
    set(gpu_list ${INST_TARGETS})
    list(FILTER gpu_list INCLUDE REGEX "^gfx94")
    if(gpu_list)
        add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
        add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
        target_compile_features(device_mha_operations PUBLIC)
        set_target_properties(device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
        rocm_install(TARGETS device_mha_operations
            EXPORT device_mha_operationsTargets)
        rocm_install(EXPORT device_mha_operationsTargets
            FILE composable_kerneldevice_mha_operationsTargets.cmake
            NAMESPACE composable_kernel::
            DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
    endif()
endif()
if(CK_DEVICE_CONTRACTION_INSTANCES)
    add_library(device_contraction_operations STATIC ${CK_DEVICE_CONTRACTION_INSTANCES})
...
...