gaoqiong / composable_kernel_ROCM / Commits / b924e330

Commit b924e330, authored Oct 03, 2024 by Jun Liu
Merge branch 'amd-develop' into amd-master
Parents: 72c9f129, 9c0811f3

Changes: 153 files in the commit. This page shows 20 changed files with 856 additions and 29 deletions (+856 −29); the diff continues on pages 2–8.
Changed files on this page:

- include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp (+34 −0, new file)
- include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp (+16 −4)
- include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp (+31 −0)
- include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp (+1 −0)
- include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp (+14 −3)
- include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp (+10 −4)
- library/include/ck/library/tensor_operation_instance/gpu/avg_pool2d_bwd.hpp (+80 −0, new file)
- library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp (+151 −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp (+34 −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc (+33 −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc (+33 −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc (+33 −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc (+33 −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc (+28 −0)
- library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp (+21 −1)
- library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp (+211 −0, new file)
- library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp (+58 −13)
- library/src/tensor_operation_instance/gpu/CMakeLists.txt (+6 −4)
- library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt (+8 −0, new file)
- library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/device_avg_pool2d_bwd_nhwc_bf16_instance.cpp (+21 −0, new file)
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockGemmShape_>
struct GemmTilePartitioner
{
    using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>;

    static constexpr ck_tile::index_t kM = BlockGemmShape::kM;
    static constexpr ck_tile::index_t kN = BlockGemmShape::kN;
    static constexpr ck_tile::index_t kK = BlockGemmShape::kK;

    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size)
    {
        ck_tile::index_t GridDimX = (M + kM - 1) / kM;
        ck_tile::index_t GridDimY = (N + kN - 1) / kN;
        ck_tile::index_t GridDimZ = batch_size;
        return dim3(GridDimX, GridDimY, GridDimZ);
    }

    CK_TILE_DEVICE auto operator()()
    {
        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
        return ck_tile::make_tuple(iM, iN);
    }
};

} // namespace ck_tile
```
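The partitioner's GridSize rounds each output dimension up to a whole number of tiles, so ragged edges still get a block. A standalone arithmetic check of that ceiling division (the tile and problem sizes below are hypothetical, chosen only to make the rounding visible):

```cpp
// Hypothetical illustration of the ceiling-division grid computation above;
// the concrete shape values are made up for this example.
#include <cstdio>

int main()
{
    const int kM = 256, kN = 128;       // tile sizes taken from a BlockGemmShape
    const int M = 1000, N = 300, B = 4; // hypothetical problem sizes

    // Same rounding-up the partitioner performs: partial tiles still get a block.
    const int grid_x = (M + kM - 1) / kM; // -> 4 blocks along M
    const int grid_y = (N + kN - 1) / kN; // -> 3 blocks along N
    std::printf("grid = (%d, %d, %d)\n", grid_x, grid_y, B);
    return 0;
}
```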
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp

```diff
@@ -4,6 +4,7 @@
 #pragma once
 
 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
 
 namespace ck_tile {
 ...
@@ -24,6 +25,14 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
     static constexpr index_t kNPerBlock = BlockGemmShape::kN;
     static constexpr index_t kKPerBlock = BlockGemmShape::kK;
 
+    static constexpr index_t AlignmentA = Problem::AlignmentA;
+    static constexpr index_t AlignmentB = Problem::AlignmentB;
+    static constexpr index_t AlignmentC = Problem::AlignmentC;
+
+    static constexpr bool kPadA = Problem::kPadA;
+    static constexpr bool kPadB = Problem::kPadB;
+    static constexpr bool kPadC = Problem::kPadC;
+
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
     {
         return ck_tile::integer_divide_ceil(
 ...
@@ -35,6 +44,11 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
             Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        return Policy::template GetSmemSize<Problem>();
+    }
+
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
 ...
@@ -140,8 +154,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
         }
 
         index_t iCounter = num_loop - 1;
-        while(iCounter > 0)
+        do
         {
             // global read i + 1
             a_block_tile = load_tile(a_copy_dram_window);
 ...
@@ -167,8 +180,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
             store_tile(b_copy_lds_window, b_block_tile_tmp);
 
             iCounter--;
-        }
+        } while(iCounter > 0);
 
         // tail
         {
```
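The loop change is the substantive edit in the last two hunks: the steady-state pipeline loop becomes post-tested, so its body always runs at least once when reached and the initial comparison disappears. A stripped-down sketch of the new shape (placeholder names, not the pipeline's real variables):

```cpp
// Minimal sketch of the do/while steady-state loop. If the surrounding code
// guarantees num_loop >= 2 (so iCounter starts >= 1), this is equivalent to
// the previous while loop minus the redundant first test; for num_loop == 1
// the body now runs once where it previously ran zero times.
void steady_state(int num_loop)
{
    int iCounter = num_loop - 1;
    do
    {
        // ... prefetch tiles for iteration i + 1, compute block GEMM for i ...
        iCounter--;
    } while(iCounter > 0);
}
```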
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp

```diff
@@ -91,6 +91,33 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
         return b_lds_block_desc;
     }
 
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
+    {
+        constexpr index_t smem_size_a =
+            sizeof(typename Problem::ADataType) *
+            MakeALdsBlockDescriptor<Problem>().get_element_space_size();
+        return smem_size_a;
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
+    {
+        constexpr index_t smem_size_b =
+            sizeof(typename Problem::BDataType) *
+            MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
+        return smem_size_b;
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
+        constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
+        index_t smem_size = 0;
+        smem_size += smem_size_a + smem_size_b;
+        return smem_size;
+    }
+
 #elif 1
     // fake XOR
     template <typename Problem>
 ...
@@ -178,6 +205,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
         constexpr index_t M2 = get_warp_size() / K0;
 #if 1 // coalesce reading for each blocks
         constexpr index_t M1 = kBlockSize / get_warp_size();
+        static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
+        static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
         constexpr index_t M0 = kMPerBlock / (M2 * M1);
         return make_static_tile_distribution(
 ...
@@ -216,6 +245,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
         constexpr index_t N2 = get_warp_size() / K0;
 #if 1 // coalesce reading for each blocks
         constexpr index_t N1 = kBlockSize / get_warp_size();
+        static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
+        static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
         constexpr index_t N0 = kNPerBlock / (N2 * N1);
         return make_static_tile_distribution(
```
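GetSmemSize is simply the sum of the per-operand LDS footprints, each computed as element size times the descriptor's element-space size. A standalone arithmetic check under assumed values (fp16 operands with 256×32 and 128×32 LDS tiles — illustrative numbers, not taken from this commit):

```cpp
// Hypothetical arithmetic mirroring GetSmemSizeA/B above; the tile extents
// are assumptions for illustration, not values from this commit.
#include <cstdio>

int main()
{
    const int sizeof_half = 2;                // fp16 element size in bytes
    const int a_elems = 256 * 32;             // assumed A LDS descriptor element space
    const int b_elems = 128 * 32;             // assumed B LDS descriptor element space
    const int smem_a = sizeof_half * a_elems; // 16384 bytes
    const int smem_b = sizeof_half * b_elems; //  8192 bytes
    std::printf("total LDS = %d bytes\n", smem_a + smem_b); // 24576
    return 0;
}
```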
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp

```diff
@@ -4,6 +4,7 @@
 #pragma once
 
 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
 
 namespace ck_tile {
```
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp

```diff
@@ -5,13 +5,17 @@
 #include "ck_tile/core.hpp"
 
+#define VectorLoadSize 16
+
 namespace ck_tile {
 
 template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
-          index_t kBlockSize_,
-          typename BlockGemmShape_>
+          typename BlockGemmShape_,
+          bool kPadA_ = false,
+          bool kPadB_ = false,
+          bool kPadC_ = false>
 struct BlockGemmPipelineProblem
 {
     using ADataType = remove_cvref_t<ADataType_>;
 ...
@@ -19,7 +23,14 @@ struct BlockGemmPipelineProblem
     using CDataType      = remove_cvref_t<CDataType_>;
     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
 
-    static constexpr index_t kBlockSize = kBlockSize_;
+    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
+
+    static constexpr bool kPadA = kPadA_;
+    static constexpr bool kPadB = kPadB_;
+    static constexpr bool kPadC = kPadC_;
+
+    static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
+    static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
+    static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
 };
 
 } // namespace ck_tile
```
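With this change the block size is derived from the shape's warp count instead of being passed explicitly, and the A/B/C alignments read directly off the 16-byte vector-load width: a padded dimension degrades to element-wise access (alignment 1), an unpadded one can load VectorLoadSize / sizeof(T) elements at a time. A standalone check of that arithmetic (a 2-byte stand-in type replaces the library's fp16 — an assumption of this example only):

```cpp
// Standalone check of the alignment rule above; `half_stand_in` is a 2-byte
// placeholder for the library's fp16 type, not code from the commit.
using half_stand_in = unsigned short;

constexpr int VectorLoadSize = 16; // bytes, as #defined in the header above

static_assert(VectorLoadSize / sizeof(half_stand_in) == 8, "fp16: 8 elements per load");
static_assert(VectorLoadSize / sizeof(float) == 4, "fp32: 4 elements per load");

int main() { return 0; }
```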
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp

```diff
@@ -7,12 +7,18 @@
 namespace ck_tile {
 
-template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
+template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
 struct TileGemmShape
 {
-    static constexpr index_t kM = kMPerTile;
-    static constexpr index_t kN = kNPerTile;
-    static constexpr index_t kK = kKPerTile;
+    using BlockTile  = remove_cvref_t<BlockTile_>;
+    using BlockWarps = remove_cvref_t<BlockWarps_>;
+    using WarpTile   = remove_cvref_t<WarpTile_>;
+
+    static constexpr index_t NumWarps =
+        reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
+
+    static constexpr index_t kM = BlockTile::at(number<0>{});
+    static constexpr index_t kN = BlockTile::at(number<1>{});
+    static constexpr index_t kK = BlockTile::at(number<2>{});
 };
 
 } // namespace ck_tile
```
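TileGemmShape now takes three compile-time sequences (block-tile extents, warps per block dimension, and per-warp tile extents) instead of three bare integers, and derives NumWarps as the product of the warp layout — which is what feeds the new kBlockSize = NumWarps * warp size in the pipeline problem above. A standalone mini-reimplementation of that contract (not the real ck_tile types; the 256×128×32 / 2×2×1 / 32×32×8 numbers are made up for illustration):

```cpp
// Standalone illustration of the new TileGemmShape contract; `seq` and
// `TileGemmShapeLike` are stand-ins for ck_tile's sequence machinery.
template <int... Is>
struct seq
{
    static constexpr int at(int i)
    {
        constexpr int vals[] = {Is...};
        return vals[i];
    }
    static constexpr int product = (Is * ... * 1);
};

template <typename BlockTile, typename BlockWarps, typename WarpTile>
struct TileGemmShapeLike
{
    static constexpr int NumWarps = BlockWarps::product; // product of warp layout
    static constexpr int kM = BlockTile::at(0);
    static constexpr int kN = BlockTile::at(1);
    static constexpr int kK = BlockTile::at(2);
};

using Shape = TileGemmShapeLike<seq<256, 128, 32>, // block tile M, N, K
                                seq<2, 2, 1>,      // warps along M, N, K
                                seq<32, 32, 8>>;   // per-warp tile

static_assert(Shape::kM == 256 && Shape::kN == 128 && Shape::kK == 32);
static_assert(Shape::NumWarps == 4); // 2 * 2 * 1

int main() { return 0; }
```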
library/include/ck/library/tensor_operation_instance/gpu/avg_pool2d_bwd.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_BF16
void add_device_avgpool_2D_bwd_nhwc_bf16_instances(
    std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, BF16, BF16, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP16
void add_device_avgpool_2D_bwd_nhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F16, F16, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP8
void add_device_avgpool_2D_bwd_nhwc_f8_instances(
    std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F8, F8, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_avgpool_2D_bwd_nhwc_f32_instances(
    std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F32, F32, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_INT8
void add_device_avgpool_2D_bwd_nhwc_int8_instances(
    std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, I8, I8, NHWC, NHWC>>>&);
#endif

template <typename DOutDataType, typename DInDataType, typename InLayout, typename OutLayout>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>>
{
    using DeviceOp = DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
        {
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16>)
                add_device_avgpool_2D_bwd_nhwc_f16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_BF16
            else if constexpr(is_same_v<DOutDataType, BF16> && is_same_v<DInDataType, BF16>)
                add_device_avgpool_2D_bwd_nhwc_bf16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP32
            else if constexpr(is_same_v<DOutDataType, F32> && is_same_v<DInDataType, F32>)
                add_device_avgpool_2D_bwd_nhwc_f32_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP8
            else if constexpr(is_same_v<DOutDataType, F8> && is_same_v<DInDataType, F8>)
                add_device_avgpool_2D_bwd_nhwc_f8_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_INT8
            else if constexpr(is_same_v<DOutDataType, I8> && is_same_v<DInDataType, I8>)
                add_device_avgpool_2D_bwd_nhwc_int8_instances(op_ptrs);
#endif
        }
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
```
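This header wires 2-D average-pool backward into the library's standard instance-factory pattern: client code specializes DeviceOperationInstanceFactory on the interface it wants, and GetInstances returns every prebuilt kernel matching the compile-time types and layouts. A hedged usage sketch (the F16/NHWC aliases are the ones this header uses; device buffers, argument construction, and linking the instance library are elided):

```cpp
// Sketch of querying the new avg-pool-bwd factory; illustrative rather than a
// complete program.
using PoolBwdOp =
    ck::tensor_operation::device::DeviceAvgPoolBwd<2, F16, F16, NHWC, NHWC>;

auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<PoolBwdOp>::GetInstances();

// Typical selection loop: build an argument for each op, skip those whose
// IsSupportedArgument(...) check fails, time the rest, and keep the fastest.
```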
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp

```diff
@@ -335,6 +335,105 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_insta
     std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16,
         PassThrough, PassThrough, PassThrough>>>& instances);
 
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
 #endif
 #if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
 void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(
 ...
@@ -618,6 +717,58 @@ struct DeviceOperationInstanceFactory<
             add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
         }
+        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                          is_same_v<CLayout, Row>)
+        {
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
+        }
+        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                          is_same_v<CLayout, Row>)
+        {
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(op_ptrs);
+            add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(op_ptrs);
+        }
     }
 #endif
 #if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
```
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp

```diff
@@ -249,6 +249,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
             }
 #endif
         }
+        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
+                     is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
+        {
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
+                         is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
+                         is_same_v<BComputeType, float>)
+            {
+                add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(op_ptrs);
+            }
+#endif
+#ifdef CK_ENABLE_FP16
+            if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+                         is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
+                         is_same_v<BComputeType, half_t>)
+            {
+                add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(op_ptrs);
+            }
+#endif
+        }
         if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
                      is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
```
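The new branch makes the NGCHW/GKYXC/NGKHW layout family reachable through the same factory entry point as the existing layouts; for each enabled data type it registers five instance sets (merged-groups, default, compute-optimized, and the intra-/inter-wave memory-bound variants), so one query sweeps all of them. A hedged sketch of such a query, reusing the template-argument spelling from the .inc declarations below (Empty_Tuple and the short type aliases are the ones these headers define; trailing defaulted parameters are omitted):

```cpp
// Sketch only: the twelve template arguments mirror the declarations added in
// the .inc files of this commit; this is not a complete, compilable program.
using ConvOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2,                  // 2 spatial dimensions
    NGCHW, GKYXC,       // input and weight layouts of the new branch
    Empty_Tuple, NGKHW, // no auxiliary D tensors; output layout
    F16, F16, Empty_Tuple, F16,
    PassThrough, PassThrough, PassThrough>;

auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<ConvOp>::GetInstances();
```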
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc

```diff
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
         PassThrough>>>& instances);
 #endif
+// grouped conv2d forward, NGCHW/GKYXC/NGKHW
+#ifdef CK_ENABLE_FP16
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+#ifdef CK_ENABLE_FP32
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
```
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc

```diff
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
         PassThrough>>>& instances);
 #endif
+// grouped conv2d forward, NGCHW/GKYXC/NGKHW
+#ifdef CK_ENABLE_FP16
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+#ifdef CK_ENABLE_FP32
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
```
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc

```diff
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
         PassThrough>>>& instances);
 #endif
+// grouped conv2d forward, NGCHW/GKYXC/NGKHW
+#ifdef CK_ENABLE_FP16
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+#ifdef CK_ENABLE_FP32
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
```
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc

```diff
@@ -171,6 +171,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
         PassThrough>>>& instances);
 #endif
+// grouped conv2d forward, NGCHW/GKYXC/NGKHW
+#ifdef CK_ENABLE_FP16
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+#ifdef CK_ENABLE_FP32
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
```
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc

```diff
@@ -39,6 +39,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
         PassThrough,
         PassThrough,
         PassThrough>>>& instances);
+void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
 #endif
 #ifdef CK_ENABLE_FP32
 ...
@@ -55,6 +69,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
         PassThrough,
         PassThrough,
         PassThrough>>>& instances);
+void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+        NGCHW, GKYXC, Empty_Tuple, NGKHW, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
 #endif
 #ifdef CK_ENABLE_BF16
```
library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 ...
@@ -23,6 +23,15 @@ void add_device_maxpool_bwd_bf16_instances(
 void add_device_maxpool_bwd_f32_instances(
     std::vector<std::unique_ptr<DeviceMaxPoolBwd<F32, I32, F32>>>&);
 #endif
+#ifdef CK_ENABLE_FP8
+void add_device_maxpool_bwd_f8_instances(
+    std::vector<std::unique_ptr<DeviceMaxPoolBwd<F8, I32, F8>>>&);
+#endif
+#ifdef CK_ENABLE_INT8
+void add_device_maxpool_bwd_int8_instances(
+    std::vector<std::unique_ptr<DeviceMaxPoolBwd<I8, I32, I8>>>&);
+#endif
 
 template <typename DOutDataType, typename IndexDataType, typename DInDataType>
 struct DeviceOperationInstanceFactory<
     ck::tensor_operation::device::DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>>
 ...
@@ -32,6 +41,7 @@ struct DeviceOperationInstanceFactory<
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16> &&
                      is_same_v<IndexDataType, I32>)
 ...
@@ -47,6 +57,16 @@ struct DeviceOperationInstanceFactory<
                      is_same_v<IndexDataType, I32>)
             add_device_maxpool_bwd_f32_instances(op_ptrs);
 #endif
+#ifdef CK_ENABLE_FP8
+        else if constexpr(is_same_v<DOutDataType, F8> && is_same_v<DInDataType, F8> &&
+                          is_same_v<IndexDataType, I32>)
+            add_device_maxpool_bwd_f8_instances(op_ptrs);
+#endif
+#ifdef CK_ENABLE_INT8
+        else if constexpr(is_same_v<DOutDataType, I8> && is_same_v<DInDataType, I8> &&
+                          is_same_v<IndexDataType, I32>)
+            add_device_maxpool_bwd_int8_instances(op_ptrs);
+#endif
         return op_ptrs;
     }
```
library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto InOutRank  = 4;
static constexpr auto WindowRank = 2;
static constexpr auto MaxOp      = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp      = ck::ReduceTensorOp::AVG;

#ifdef CK_ENABLE_FP16
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool2d_fwd_nhwc_index_f16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_BF16
// BF16
void add_device_pool2d_fwd_nhwc_bf16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_bf16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, AvgOp, false>>>&);
// BF16 - return index
void add_device_pool2d_fwd_nhwc_index_bf16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f32_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool2d_fwd_nhwc_index_f32_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_INT8
// I8
void add_device_pool2d_fwd_nhwc_i8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_i8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, AvgOp, false>>>&);
// I8 - return index
void add_device_pool2d_fwd_nhwc_index_i8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP8
// F8
void add_device_pool2d_fwd_nhwc_f8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, AvgOp, false>>>&);
// F8 - return index
void add_device_pool2d_fwd_nhwc_index_f8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif

template <typename InDataType,
          typename OutDataType,
          typename IndexDataType,
          typename InLayout,
          typename OutLayout,
          ck::ReduceTensorOp ReduceOpId,
          bool OutputIndex>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DevicePoolFwd<InOutRank, WindowRank, InDataType, OutDataType,
                                                IndexDataType, InLayout, OutLayout, ReduceOpId,
                                                OutputIndex>>
{
    using DeviceOp = DevicePoolFwd<InOutRank, WindowRank, InDataType, OutDataType, IndexDataType,
                                   InLayout, OutLayout, ReduceOpId, OutputIndex>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
        {
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
                         is_same_v<IndexDataType, I32>)
            {
                if constexpr(OutputIndex && ReduceOpId == MaxOp)
                {
                    add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
                }
                else
                {
                    add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
                }
            }
#endif
#ifdef CK_ENABLE_BF16
            else if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
                              is_same_v<IndexDataType, I32>)
            {
                if constexpr(OutputIndex && ReduceOpId == MaxOp)
                {
                    add_device_pool2d_fwd_nhwc_index_bf16_instances(op_ptrs);
                }
                else
                {
                    add_device_pool2d_fwd_nhwc_bf16_instances(op_ptrs);
                }
            }
#endif
#ifdef CK_ENABLE_FP32
            else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
                              is_same_v<IndexDataType, I32>)
            {
                if constexpr(OutputIndex && ReduceOpId == MaxOp)
                {
                    add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
                }
                else
                {
                    add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
                }
            }
#endif
#ifdef CK_ENABLE_INT8
            else if constexpr(is_same_v<InDataType, I8> && is_same_v<OutDataType, I8> &&
                              is_same_v<IndexDataType, I32>)
            {
                if constexpr(OutputIndex && ReduceOpId == MaxOp)
                {
                    add_device_pool2d_fwd_nhwc_index_i8_instances(op_ptrs);
                }
                else
                {
                    add_device_pool2d_fwd_nhwc_i8_instances(op_ptrs);
                }
            }
#endif
#ifdef CK_ENABLE_FP8
            else if constexpr(is_same_v<InDataType, F8> && is_same_v<OutDataType, F8> &&
                              is_same_v<IndexDataType, I32>)
            {
                if constexpr(OutputIndex && ReduceOpId == MaxOp)
                {
                    add_device_pool2d_fwd_nhwc_index_f8_instances(op_ptrs);
                }
                else
                {
                    add_device_pool2d_fwd_nhwc_f8_instances(op_ptrs);
                }
            }
#endif
        }
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
```
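The dispatch rule in this factory is worth noting: the index-returning instance set is selected only when OutputIndex is true and the reduction is MAX; every other combination falls back to the plain instances. A hedged usage sketch against the header above (buffer setup and argument construction are omitted; InOutRank = 4 and WindowRank = 2 are the constants fixed by this header):

```cpp
// Sketch of requesting FP16 NHWC max pooling with argmax indices; illustrative,
// not a complete program.
using PoolOp = ck::tensor_operation::device::DevicePoolFwd<
    4, 2, F16, F16, I32, NHWC, NHWC, ck::ReduceTensorOp::MAX, /*OutputIndex=*/true>;

auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<PoolOp>::GetInstances();
```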
library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 ...
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 3;
 static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
 static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
+#ifdef CK_ENABLE_FP16
 // FP16
 void add_device_pool3d_fwd_ndhwc_f16_instances(
     std::vector<std::unique_ptr<
 ...
@@ -36,8 +36,22 @@ void add_device_pool3d_fwd_ndhwc_f16_instances(
 void add_device_pool3d_fwd_ndhwc_index_f16_instances(
     std::vector<std::unique_ptr<
         DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
+#endif
+#ifdef CK_ENABLE_BF16
+using F8 = ck::f8_t;
+// F8
+void add_device_pool3d_fwd_ndhwc_f8_instances(
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, MaxOp, false>>>&);
+void add_device_pool3d_fwd_ndhwc_f8_instances(
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, AvgOp, false>>>&);
+// FP8 - return index
+void add_device_pool3d_fwd_ndhwc_index_f8_instances(
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, MaxOp, true>>>&);
 // BF16
 void add_device_pool3d_fwd_ndhwc_bf16_instances(
     std::vector<std::unique_ptr<
 ...
@@ -51,8 +65,7 @@ void add_device_pool3d_fwd_ndhwc_bf16_instances(
 void add_device_pool3d_fwd_ndhwc_index_bf16_instances(
     std::vector<std::unique_ptr<
         DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
+#endif
+#ifdef CK_ENABLE_FP32
 // FP32
 void add_device_pool3d_fwd_ndhwc_f32_instances(
     std::vector<std::unique_ptr<
 ...
@@ -66,7 +79,21 @@ void add_device_pool3d_fwd_ndhwc_f32_instances(
 void add_device_pool3d_fwd_ndhwc_index_f32_instances(
     std::vector<std::unique_ptr<
         DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, true>>>&);
+#endif
+// I8
+void add_device_pool3d_fwd_ndhwc_i8_instances(
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NDHWC, NDHWC, MaxOp, false>>>&);
+void add_device_pool3d_fwd_ndhwc_i8_instances(
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NDHWC, NDHWC, AvgOp, false>>>&);
+// I8 - return index
+void add_device_pool3d_fwd_ndhwc_index_i8_instances(
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NDHWC, NDHWC, MaxOp, true>>>&);
 
 template <typename InDataType,
           typename OutDataType,
           typename IndexDataType,
 ...
@@ -99,7 +126,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
 
         if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
         {
+#ifdef CK_ENABLE_FP16
             if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
                          is_same_v<IndexDataType, I32>)
             {
 ...
@@ -112,8 +138,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
                     add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
                 }
             }
+#endif
+#ifdef CK_ENABLE_BF16
             else if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
                               is_same_v<IndexDataType, I32>)
             {
 ...
@@ -126,8 +150,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
                     add_device_pool3d_fwd_ndhwc_bf16_instances(op_ptrs);
                 }
             }
+#endif
+#ifdef CK_ENABLE_FP32
             else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
                               is_same_v<IndexDataType, I32>)
             {
 ...
@@ -140,7 +162,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
                     add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
                 }
             }
+#endif
+            else if constexpr(is_same_v<InDataType, F8> && is_same_v<OutDataType, F8> &&
+                              is_same_v<IndexDataType, I32>)
+            {
+                if constexpr(OutputIndex && ReduceOpId == MaxOp)
+                {
+                    add_device_pool3d_fwd_ndhwc_index_f8_instances(op_ptrs);
+                }
+                else
+                {
+                    add_device_pool3d_fwd_ndhwc_f8_instances(op_ptrs);
+                }
+            }
+            else if constexpr(is_same_v<InDataType, I8> && is_same_v<OutDataType, I8> &&
+                              is_same_v<IndexDataType, I32>)
+            {
+                if constexpr(OutputIndex && ReduceOpId == MaxOp)
+                {
+                    add_device_pool3d_fwd_ndhwc_index_i8_instances(op_ptrs);
+                }
+                else
+                {
+                    add_device_pool3d_fwd_ndhwc_i8_instances(op_ptrs);
+                }
+            }
         }
 
         return op_ptrs;
```
library/src/tensor_operation_instance/gpu/CMakeLists.txt

```diff
@@ -102,12 +102,14 @@ function(add_instance_library INSTANCE_NAME)
       set(FMHA_FWD_FAST_EXP2 true)
     endif()
     if(FMHA_FWD_FAST_EXP2)
-      list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
+      list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
     else()
-      list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
+      list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
     endif()
-    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
-    target_compile_options(device_mha_instance PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
+    list(APPEND FMHA_COMPILE_OPTIONS -Wno-float-equal)
+    list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
+    list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
+    target_compile_options(device_mha_instance PRIVATE ${FMHA_COMPILE_OPTIONS})
   endif()
   target_compile_features(${INSTANCE_NAME} PUBLIC)
```
library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt (new file, mode 100644)

```cmake
set(DEVICE_AVGPOOL_2D_BWD_INSTANCES)
list(APPEND DEVICE_AVGPOOL_2D_BWD_INSTANCES device_avg_pool2d_bwd_nhwc_bf16_instance.cpp
                                            device_avg_pool2d_bwd_nhwc_f16_instance.cpp
                                            device_avg_pool2d_bwd_nhwc_f32_instance.cpp
                                            device_avg_pool2d_bwd_nhwc_f8_instance.cpp
                                            device_avg_pool2d_bwd_nhwc_int8_instance.cpp)
add_instance_library(device_avg_pool2d_bwd_instance ${DEVICE_AVGPOOL_2D_BWD_INSTANCES})
```
library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/device_avg_pool2d_bwd_nhwc_bf16_instance.cpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_avgpool_2D_bwd_nhwc_bf16_instances(
    std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, BF16, BF16, NHWC, NHWC>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_avgpool_2D_bwd_nhwc_instances<BF16, BF16, F32>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
```