gaoqiong / composable_kernel · Commits · 8820cf9f

Commit 8820cf9f, authored May 04, 2023 by Po-Yen, Chen

    Merge branch 'develop' into feature/integrage-karg-simplification-pr

Parents: cb46ef7a, 4feebedd

The merge touches 157 files in total; this page shows 20 of them, with 563 additions and 106 deletions (+563, -106).
Files shown on this page:

  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp    (+2, -1)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp  (+80, -10)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp    (+2, -1)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp    (+2, -1)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp    (+2, -1)
  include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                  (+59, -1)
  include/ck/utility/amd_xdlops.hpp                                     (+39, -1)
  include/ck/utility/data_type.hpp                                      (+2, -0)
  include/ck/utility/math.hpp                                           (+4, -0)
  library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp  (+2, -0)
  library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp           (+66, -0)
  library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp              (+66, -0)
  library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp    (+24, -34)
  library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp                   (+58, -0)
  library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp (new file) (+93, -0)
  library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp  (+17, -17)
  library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp    (+17, -17)
  library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp       (+11, -11)
  library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp         (+11, -11)
  library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt               (+6, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp

@@ -42,7 +42,8 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
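Note on the pattern: these grid kernels guard their bodies with a target allow-list so that, during the device compilation pass, only MFMA-capable GPUs get real code while every other target compiles an empty stub; the `!defined(__HIP_DEVICE_COMPILE__)` arm keeps the body visible to the host pass for type checking. This commit simply adds gfx940 to the list. A minimal sketch of the same idiom outside of CK (kernel name and body are illustrative, not from the commit):

    #include <hip/hip_runtime.h>

    __global__ void mfma_only_kernel(float* out)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
        defined(__gfx940__))
        out[threadIdx.x] = 1.0f; // real MFMA work would live here
    #else
        (void)out;               // stub body keeps unsupported targets linking
    #endif
    }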
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp

@@ -15,26 +15,32 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"

 namespace ck {

 template <typename GridwiseGemm,
           bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
+          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+          typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
+        kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
+                                             const Block2CTileMap& b2c_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

     __shared__ uint8_t p_shared[shared_size];

     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-        karg, static_cast<void*>(p_shared));
+        karg, static_cast<void*>(p_shared), b2c_map);
 #else
     ignore = karg;
+    ignore = b2c_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

@@ -478,8 +484,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
     }

-    template <bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-    __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
+    // return block_id to C matrix tile idx (m0, n0, k_split) mapping
+    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
+    {
+        return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
+    }
+
+    using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
+    using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
+
+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              typename Block2CTileMap>
+    __device__ static void
+    Run(const Argument& karg, void* __restrict__ p_shared_block, const Block2CTileMap& block_2_ctile_map)
     {
         const FloatAB* p_a_grid = karg.p_a_grid;
         const FloatAB* p_b_grid = karg.p_b_grid;

@@ -504,11 +523,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
-
-        const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
-        const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
-        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        // divide block work by [KBatch, M, N]
+        const auto block_work_idx =
+            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+        {
+            return;
+        }
+
+        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
+        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
+        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =

@@ -651,6 +680,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
         // sanity check
+#if 1
         auto blockwise_gemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                 FloatAB,

@@ -662,6 +692,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                                 MRepeat,
                                                                 NRepeat,
                                                                 K1>{};
+#else
+        auto blockwise_gemm =
+            BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+                BlockSize,
+                FloatAB,
+                FloatAcc,
+                decltype(a_k0_m_k1_block_desc),
+                decltype(b_k0_n_k1_block_desc),
+                MPerXDL,
+                NPerXDL,
+                MRepeat,
+                NRepeat,
+                K1>{};
+#endif

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

@@ -680,6 +724,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

+#if 0
         // preload data into LDS
         {
             a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);

@@ -725,6 +770,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
         }
+#else
+        // gridwise GEMM pipeline
+        const auto gridwise_gemm_pipeline =
+            GridwiseGemmPipeline_Selector<PipelineVersion::v2, 1, LoopScheduler::Default>();
+
+        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
+            (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
+            (K0PerBlock * K1));
+
+        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
+                                                               a_b_k0_m_k1_block_desc,
+                                                               a_blockwise_copy,
+                                                               a_grid_buf,
+                                                               a_block_buf,
+                                                               a_block_slice_copy_step,
+                                                               b_b_k0_n_k1_grid_desc,
+                                                               b_b_k0_n_k1_block_desc,
+                                                               b_blockwise_copy,
+                                                               b_grid_buf,
+                                                               b_block_buf,
+                                                               b_block_slice_copy_step,
+                                                               blockwise_gemm,
+                                                               c_thread_buf,
+                                                               num_k_block_main_loop);
+#endif

         // output: register to global memory
         {
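The structural change in this file: `kernel_gemm_xdlops_v2r4r2_simplified` no longer derives its (k_batch, m, n) tile from `blockIdx.z/y/x` but asks a `Block2CTileMap` object to decode a flat block id and to reject out-of-range tiles before any work is done. A toy stand-in showing the contract such a map has to satisfy (this struct is illustrative only, not CK's `BlockToCTileMap_3DGrid_KSplit`):

    #include <hip/hip_runtime.h>

    // Decode a flat block id into (k_batch, m_tile, n_tile), the same
    // [KBatch, M, N] split that Run() expects from block_work_idx.
    struct ToyBlock2CTileMap
    {
        int m_tiles, n_tiles, k_batch;

        __host__ __device__ void CalculateBottomIndex(int block_1d_id, int idx[3]) const
        {
            idx[0] = block_1d_id / (m_tiles * n_tiles); // k_batch_id
            idx[1] = (block_1d_id / n_tiles) % m_tiles; // block_m_id
            idx[2] = block_1d_id % n_tiles;             // block_n_id
        }

        // Blocks mapped past the valid tile grid simply return early.
        __host__ __device__ bool ValidCTileIndex(const int idx[3]) const
        {
            return idx[0] < k_batch && idx[1] < m_tiles && idx[2] < n_tiles;
        }
    };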
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp

@@ -46,7 +46,8 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp

@@ -49,7 +49,8 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp

@@ -53,7 +53,8 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -27,6 +27,8 @@ enum struct MfmaInstr
     mfma_f32_16x16x8bf16,
     mfma_i32_32x32x8i8,
     mfma_i32_16x16x16i8,
+    mfma_i32_32x32x16i8,
+    mfma_i32_16x16x32i8,
     mfma_f64_16x16x4f64
 };

@@ -386,6 +388,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x16i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_16x16x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
 {

@@ -524,17 +570,29 @@ struct MfmaSelector
 #endif
     }

+#if defined(CK_USE_AMD_MFMA_GFX940)
+    template <>
+    static constexpr auto GetMfma<int8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_i32_32x32x16i8;
+    }
+
+    template <>
+    static constexpr auto GetMfma<int8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_i32_16x16x32i8;
+    }
+#else
     template <>
     static constexpr auto GetMfma<int8_t, 32, 32>()
     {
         return MfmaInstr::mfma_i32_32x32x8i8;
     }

     template <>
     static constexpr auto GetMfma<int8_t, 16, 16>()
     {
         return MfmaInstr::mfma_i32_16x16x16i8;
     }
+#endif

     static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
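The constants in the two new `mfma_type` entries are mutually constrained, which is worth checking whenever an instruction is added like this: a 64-lane wave holds the whole m_per_blk x n_per_blk accumulator at num_regs_per_blk registers per lane, and with is_k_reduction the lanes split into num_input_blks groups that each consume k_per_blk elements of K, so one instruction covers num_input_blks * k_per_blk of K. A sketch of those checks for the 32x32x16 entry (values copied from the diff; the assertions themselves are not part of the commit):

    // mfma_i32_32x32x16i8 trait values, from the diff above
    constexpr int group_size          = 4;
    constexpr int num_groups_per_blk  = 4;
    constexpr int num_regs_per_blk    = 16;
    constexpr int num_threads_per_blk = 32;
    constexpr int wave_size           = 64;
    constexpr int num_input_blks      = 2;
    constexpr int num_output_blks     = 1;
    constexpr int m_per_blk = 32, n_per_blk = 32, k_per_blk = 8;

    // 64 lanes x 16 regs = 1024 accumulators = one 32x32 output block
    static_assert(num_regs_per_blk * wave_size == m_per_blk * n_per_blk * num_output_blks, "");
    // lanes split into 2 groups of 32 along K (k-reduction layout)
    static_assert(num_threads_per_blk * num_input_blks == wave_size, "");
    // register-file grouping is consistent
    static_assert(group_size * num_groups_per_blk == num_regs_per_blk, "");
    // each lane group eats k_per_blk of K, so one issue covers K = 2 * 8 = 16,
    // matching the "32x32x16" in the instruction name
    static_assert(num_input_blks * k_per_blk == 16, "");

The same identities hold for the 16x16x32 entry (4 * 64 = 256 = 16 * 16, and 4 * 8 = 32 of K per issue).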
include/ck/utility/amd_xdlops.hpp

@@ -297,6 +297,44 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_32x32x16i8;
+
+template <>
+struct intrin_mfma_i32_32x32x16i8<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x16_i8(
+            bit_cast<int64_t>(reg_a),
+            bit_cast<int64_t>(reg_b),
+            reg_c.template AsType<int32x16_t>()[Number<0>{}],
+            0,
+            0,
+            0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_16x16x32i8;
+
+template <>
+struct intrin_mfma_i32_16x16x32i8<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x32_i8(
+            bit_cast<int64_t>(reg_a),
+            bit_cast<int64_t>(reg_b),
+            reg_c.template AsType<int32x4_t>()[Number<0>{}],
+            0,
+            0,
+            0);
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f64_16x16x4f64;

@@ -306,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     template <class FloatC>
     __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
     {
-#ifdef __gfx90a__
+#if defined(__gfx90a__) || defined(__gfx940__)
         reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
             reg_a,
             reg_b,
             reg_c.template AsType<double4_t>()[Number<0>{}],
             0,
             0,
             0);
 #else
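Both new wrappers follow the gfx940 int8 MFMA calling convention: the eight 8-bit A values and eight 8-bit B values are reinterpreted as single 64-bit integers, and the accumulator vector is passed in and returned by value with the cbsz/abid/blgp modifiers left at zero. A stripped-down sketch of the same call (illustrative only; compile with --offload-arch=gfx940, and the memcpy packing stands in for CK's bit_cast):

    #include <hip/hip_runtime.h>

    typedef int int32x16 __attribute__((ext_vector_type(16)));

    __device__ int32x16 mfma_step(const char a8[8], const char b8[8], int32x16 acc)
    {
    #if defined(__gfx940__)
        long a, b; // 8 x int8 packed into one 64-bit operand each
        __builtin_memcpy(&a, a8, 8);
        __builtin_memcpy(&b, b8, 8);
        // last three args: cbsz, abid, blgp modifiers, all 0 as in CK
        return __builtin_amdgcn_mfma_i32_32x32x16_i8(a, b, acc, 0, 0, 0);
    #else
        return acc; // no-op stub on other targets
    #endif
    }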
include/ck/utility/data_type.hpp

@@ -898,6 +898,8 @@ struct vector_type<T, 256>
     }
 };

+using int64_t = long;
+
 // fp64
 using double2_t = typename vector_type<double, 2>::type;
 using double4_t = typename vector_type<double, 4>::type;
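A note on why this alias appears (reasoning, not text from the commit): the new int8 MFMA wrappers above do `bit_cast<int64_t>(...)` on packed int8x8_t operands, so a 64-bit integer type must be visible inside namespace ck without pulling <cstdint> into this header; on the LP64 targets ROCm supports, `long` is that type:

    static_assert(sizeof(long) == 8, "LP64 assumption behind ck::int64_t = long");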
include/ck/utility/math.hpp

@@ -168,6 +168,10 @@ __device__ double exp<double>(double x)
     return exp(x);
 }

+static inline __host__ float exp(float x) { return std::expf(x); }
+
+static inline __host__ double exp(double x) { return std::exp(x); }
+
 // greatest common divisor, aka highest common factor
 __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
 {
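With these two overloads, the same `exp` spelling now resolves on the host as well as on the device, so templated reference code can share math with kernels. A usage sketch (assuming, as the rest of math.hpp suggests, that these overloads live in namespace ck::math; the CK include is abbreviated):

    // #include "ck/utility/math.hpp"  (CK include path assumed)
    #include <cstdio>

    int main()
    {
        float  f = ck::math::exp(1.0f); // host overload added here -> std::expf
        double d = ck::math::exp(1.0);  // host overload added here -> std::exp
        std::printf("%f %f\n", f, d);   // ~2.718282 2.718282
        return 0;
    }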
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

@@ -26,6 +26,7 @@ using Empty_Tuple = ck::Tuple<>;
 using F16_Tuple     = ck::Tuple<F16>;
 using F16_F16_Tuple = ck::Tuple<F16, F16>;
+using F64_Tuple     = ck::Tuple<F64>;
 using F32_Tuple     = ck::Tuple<F32>;
 using I32_Tuple     = ck::Tuple<I32>;
 using I32_F32_Tuple = ck::Tuple<I32, F32>;

@@ -95,6 +96,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
 using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
 using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;
 using Gelu        = ck::tensor_operation::element_wise::Gelu;
+using Swish       = ck::tensor_operation::element_wise::Swish;

 template <typename Activation>
 using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp

@@ -19,6 +19,7 @@
 namespace device {
 namespace instance {

+// float
 void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
     std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, ...

@@ -67,6 +68,55 @@
                                                            PassThrough,
                                                            PassThrough,
                                                            Bilinear>>>& instances);

+// double
+void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
+    std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, F64_Tuple, F64,
+        PassThrough, PassThrough, Bilinear>>>& instances);
+
+void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(
+    std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, F64_Tuple, F64,
+        PassThrough, PassThrough, Bilinear>>>& instances);
+
+void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(
+    std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, F64_Tuple, F64,
+        PassThrough, PassThrough, Bilinear>>>& instances);
+
+void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
+    std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, F64_Tuple, F64,
+        PassThrough, PassThrough, Bilinear>>>& instances);

 // Contraction + Bilinear
 template <index_t NumDimM,
           index_t NumDimN,

@@ -118,6 +168,22 @@
             }
         }

+        if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
+                     is_same_v<DDataType, double> && is_same_v<EDataType, double>)
+        {
+            if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
+            {
+                add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(op_ptrs);
+                add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(op_ptrs);
+                add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(op_ptrs);
+                add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(op_ptrs);
+            }
+        }
+
         return op_ptrs;
     }
 };
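How a user reaches the new FP64 instances (a sketch; the factory specialization is the one declared in this header, and F64, F64_Tuple, PassThrough, and Bilinear are the aliases it already uses — surrounding includes are omitted):

    using ck::tensor_operation::device::DeviceContractionMultipleD;
    namespace instance = ck::tensor_operation::device::instance;

    // all-double, 2+2+2-dimensional bilinear contraction
    using DeviceOp = DeviceContractionMultipleD<2, 2, 2,
                                                F64, F64, F64_Tuple, F64,
                                                PassThrough, PassThrough, Bilinear>;

    // Before this commit the f64 branch did not exist and this list was empty;
    // now it returns the kknn/knnn/mnnn/mknn instances registered above.
    auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();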
library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp

@@ -19,6 +19,7 @@
 namespace device {
 namespace instance {

+// float
 void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(
     std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, ...

@@ -67,6 +68,55 @@
                                                            PassThrough,
                                                            PassThrough,
                                                            Scale>>>& instances);

+// double
+void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
+    std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, Empty_Tuple, F64,
+        PassThrough, PassThrough, Scale>>>& instances);
+
+void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
+    std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, Empty_Tuple, F64,
+        PassThrough, PassThrough, Scale>>>& instances);
+
+void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
+    std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, Empty_Tuple, F64,
+        PassThrough, PassThrough, Scale>>>& instances);
+
+void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
+    std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, Empty_Tuple, F64,
+        PassThrough, PassThrough, Scale>>>& instances);

 // Contraction + Scale
 template <index_t NumDimM,
           index_t NumDimN,

@@ -117,6 +167,22 @@
             }
         }

+        if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
+                     is_same_v<EDataType, double>)
+        {
+            if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
+            {
+                add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(op_ptrs);
+                add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(op_ptrs);
+                add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(op_ptrs);
+                add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(op_ptrs);
+            }
+        }
+
         return op_ptrs;
     }
 };
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp

@@ -117,20 +117,6 @@
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances);

-void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<
-        2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, Empty_Tuple, int8_t,
-        PassThrough, PassThrough, PassThrough>>>& instances);
-
 void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
                                                               GNHWC,

@@ -159,20 +145,21 @@
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances);

-void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<
-        2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, Empty_Tuple, int8_t,
-        PassThrough, PassThrough, PassThrough>>>& instances);
-
-// grouped conv2d forward, NHWGC/GKYXC/NHWGK
+// grouped conv2d forward, NHWGC/GKYXC/NHWGK
+void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<
+        2, NHWGC, GKYXC, Empty_Tuple, NHWGK, BF16, BF16, Empty_Tuple, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
                                                               NHWGC,

@@ -187,6 +174,20 @@
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances);

+void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<
+        2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
 // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,

@@ -385,12 +386,6 @@
         {
             add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
         }
-        else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
-                          is_same_v<OutDataType, int8_t>)
-        {
-            add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
-            add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
-        }
     }
     else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
                       is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)

@@ -398,7 +393,7 @@
         if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                      is_same_v<OutDataType, float>)
         {
-            // no instance
+            add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
         }
         else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                           is_same_v<OutDataType, half_t>)

@@ -409,12 +404,7 @@
                           is_same_v<WeiDataType, ck::bhalf_t> &&
                           is_same_v<OutDataType, ck::bhalf_t>)
         {
-            // no instance
+            add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
         }
-        else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
-                          is_same_v<OutDataType, int8_t>)
-        {
-            // no instance
-        }
     }
     else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp

@@ -68,6 +68,58 @@
                                               PassThrough,
                                               PassThrough,
                                               PassThrough>>>& instances);

+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
+        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16,
+        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16,
+        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16,
+        Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
 template <typename ALayout,
           typename BLayout,
           typename ELayout,

@@ -109,11 +161,17 @@
                          is_same_v<ELayout, Row>)
         {
             add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+            add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+            add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                           is_same_v<ELayout, Row>)
         {
             add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+            add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+            add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                           is_same_v<ELayout, Row>)
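What split-K buys (background reasoning, not from the diff): each group's K loop is divided across k_batch blocks whose partial sums are then combined, which recovers parallelism when M x N alone cannot fill the GPU. A scalar sketch of the arithmetic for one output element:

    #include <vector>

    // One C[m][n] computed split-K style; A_row/B_col hold the K-length
    // slices, and K is assumed divisible by k_batch for simplicity.
    float splitk_dot(const std::vector<float>& A_row,
                     const std::vector<float>& B_col, int K, int k_batch)
    {
        std::vector<float> partial(k_batch, 0.0f);
        for(int kb = 0; kb < k_batch; ++kb) // in CK each kb is its own block
            for(int k = kb * (K / k_batch); k < (kb + 1) * (K / k_batch); ++k)
                partial[kb] += A_row[k] * B_col[k];

        float c = 0.0f; // cross-slice reduction (atomics or a second pass on GPU)
        for(float p : partial)
            c += p;
        return c;
    }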
library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// FP16
void add_device_normalization_rank_5_3_swish_f16_instances(
    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&);

// FP32
void add_device_normalization_rank_5_3_swish_f32_instances(
    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);

// [x, gamma, beta, y] = [f16, f32, f32, f16]
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
    std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&);

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalization<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      F32,
                                                      YDataType,
                                                      ck::tensor_operation::element_wise::Swish,
                                                      Rank,
                                                      NumReduceDim>>
{
    using DeviceOp = DeviceNormalization<XDataType,
                                         GammaDataType,
                                         BetaDataType,
                                         F32,
                                         YDataType,
                                         ck::tensor_operation::element_wise::Swish,
                                         Rank,
                                         NumReduceDim>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
                     is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
        {
            if constexpr(Rank == 5 && NumReduceDim == 3)
            {
                add_device_normalization_rank_5_3_swish_f16_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
                          is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
        {
            if constexpr(Rank == 5 && NumReduceDim == 3)
            {
                add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
                          is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16>)
        {
            if constexpr(Rank == 5 && NumReduceDim == 3)
            {
                add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
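Usage sketch for the new header (types as declared above; F16/F32 and is_same_v come from the factory header this file pulls in, and other setup is abbreviated). Rank 5 with 3 reduced dimensions matches a (N, H, W, G, C) tensor reduced over (H, W, C), the groupnorm-style shape used elsewhere in CK:

    using ck::tensor_operation::device::DeviceNormalization;
    namespace instance = ck::tensor_operation::device::instance;

    using DeviceOp = DeviceNormalization<F16, F16, F16, F32, F16, // x/gamma/beta/accum/y
                                         ck::tensor_operation::element_wise::Swish,
                                         5, 3>;                   // Rank, NumReduceDim

    // returns the rank_5_3 swish f16 instances registered above
    auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();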
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp

@@ -17,14 +17,14 @@
 namespace device {
 namespace instance {

-// grouped conv2d forward, GNHWC/GKYXC/GNHWK
+// grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<
-        2, GNHWC, GKYXC, GK_GK_Tuple, GNHWK,
+        2, NHWGC, GKYXC, GK_GK_Tuple, NHWGK,
         int8_t, int8_t, I32_F32_Tuple, ...>>>& instances);

The hunks at @@ -36, -52, -68, -83, and -99 apply the identical GNHWC -> NHWGC and
GNHWK -> NHWGK substitution to the remaining declarations:
add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances,
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances,
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances,
add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances, and
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances.

The two factory layout checks move with them:

@@ -154,9 +154,9 @@ (and identically at @@ -220,9 +220,9 @@)
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

-        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
+        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
                      is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
-                     is_same_v<OutLayout, GNHWK>)
+                     is_same_v<OutLayout, NHWGK>)
         {
             if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                          is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
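All four quantization headers on this page make the same move: the int8 instance declarations and the factory layout checks switch from GNHWC/GNHWK to NHWGC/NHWGK. The difference is purely where the group index sits in the strides; with NHWGC the per-pixel channels of all groups are adjacent in memory. A sketch of the two offset formulas (illustrative helper functions, not CK code):

    #include <cstddef>

    // GNHWC: group index is the slowest-varying dimension
    std::size_t offset_gnhwc(std::size_t g, std::size_t n, std::size_t h, std::size_t w,
                             std::size_t c, std::size_t N, std::size_t H, std::size_t W,
                             std::size_t C)
    {
        return (((g * N + n) * H + h) * W + w) * C + c;
    }

    // NHWGC: group index sits between the spatial dims and C
    std::size_t offset_nhwgc(std::size_t n, std::size_t h, std::size_t w, std::size_t g,
                             std::size_t c, std::size_t G, std::size_t H, std::size_t W,
                             std::size_t C)
    {
        return (((n * H + h) * W + w) * G + g) * C + c;
    }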
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp

@@ -17,14 +17,14 @@
 namespace device {
 namespace instance {

-// grouped conv2d forward, GNHWC/GKYXC/GNHWK
+// grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<
-        2, GNHWC, GKYXC, GK_Tuple, GNHWK,
+        2, NHWGC, GKYXC, GK_Tuple, NHWGK,
         int8_t, int8_t, I32_Tuple, ...>>>& instances);

The hunks at @@ -36, -51, -67, -82, and -97 apply the identical layout substitution to
add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances,
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances,
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances,
add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances, and
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances.

@@ -152,9 +152,9 @@ (and identically at @@ -218,9 +218,9 @@)
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

-        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
+        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
                      is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
-                     is_same_v<OutLayout, GNHWK>)
+                     is_same_v<OutLayout, NHWGK>)
         {
             if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                          is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp

@@ -17,13 +17,13 @@
 namespace device {
 namespace instance {

-// grouped conv2d forward, GNHWC/GKYXC/GNHWK
+// grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_perchannel_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<
-        2, GNHWC, GKYXC, GK_Tuple, GNHWK,
+        2, NHWGC, GKYXC, GK_Tuple, NHWGK,
         int8_t, int8_t, F32_Tuple, ...>>>& instances);

The hunks at @@ -35, -50, and -65 apply the identical layout substitution to
add_device_conv2d_dl_relu_perchannel_quantization_int8_instances,
add_device_conv2d_xdl_perchannel_quantization_int8_instances, and
add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances.

@@ -119,9 +119,9 @@
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

-        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
+        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
                      is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
-                     is_same_v<OutLayout, GNHWK>)
+                     is_same_v<OutLayout, NHWGK>)
         {
             if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                          is_same_v<OutDataType, int8_t>)
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp

@@ -17,13 +17,13 @@
 namespace device {
 namespace instance {

-// grouped conv2d forward, GNHWC/GKYXC/GNHWK
+// grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_perlayer_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<
-        2, GNHWC, GKYXC, Empty_Tuple, GNHWK,
+        2, NHWGC, GKYXC, Empty_Tuple, NHWGK,
         int8_t, int8_t, Empty_Tuple, ...>>>& instances);

The hunks at @@ -35, -50, and -65 apply the identical layout substitution to
add_device_conv2d_dl_relu_perlayer_quantization_int8_instances,
add_device_conv2d_xdl_perlayer_quantization_int8_instances, and
add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances.

@@ -117,8 +117,8 @@
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

-        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
-                     is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
+        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
+                     is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
         {
             if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                          is_same_v<OutDataType, int8_t>)
library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt

 add_instance_library(device_contraction_bilinear_instance
+    #float
     device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
     device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
     device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
     device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp
+    #double
+    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
+    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
+    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
+    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp
 )