Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
185af92b
Commit
185af92b
authored
Apr 26, 2023
by
ltqin
Browse files
Merge branch 'develop' into lib_gemm_softmax_gemm_type
parents
5f4a0f73
8bb2bb4a
Changes
77
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1364 additions
and
619 deletions
+1364
-619
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc
...tization/run_conv2d_fwd_perlayer_quantization_example.inc
+3
-3
include/ck/ck.hpp
include/ck/ck.hpp
+5
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
...ensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
+39
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+60
-412
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+612
-0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+48
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+9
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+7
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
...or_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+427
-110
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
...or_operation_instance/gpu/grouped_convolution_forward.hpp
+24
-34
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
...ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
+58
-0
library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp
...ary/tensor_operation_instance/gpu/normalization_swish.hpp
+12
-0
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp
...uped_convolution_bias_forward_perchannel_quantization.hpp
+17
-17
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp
...rouped_convolution_bias_forward_perlayer_quantization.hpp
+17
-17
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp
...n/grouped_convolution_forward_perchannel_quantization.hpp
+11
-11
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp
...ion/grouped_convolution_forward_perlayer_quantization.hpp
+11
-11
No files found.
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc
View file @
185af92b
...
...
@@ -162,9 +162,9 @@ int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element
const
auto
in_element_op
=
InElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
G
KYXC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWK
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
C
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
KYX
G
C
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
K
;
const
auto
in_g_n_c_wis_desc
=
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
...
...
include/ck/ck.hpp
View file @
185af92b
...
...
@@ -168,6 +168,11 @@
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
#endif
namespace
ck
{
enum
struct
InMemoryDataOperationEnum
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
View file @
185af92b
...
...
@@ -31,7 +31,7 @@ struct DeviceGroupedGemm : public BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static_assert
(
DsLayout
::
Size
()
==
DsDataType
::
Size
(),
"wrong! inconsis
i
ten NumDTensor"
);
static_assert
(
DsLayout
::
Size
()
==
DsDataType
::
Size
(),
"wrong! inconsisten
t
NumDTensor"
);
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
0 → 100644
View file @
185af92b
#pragma once
#include <iostream>
#include <vector>
#include "device_grouped_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemmSplitK
:
public
DeviceGroupedGemm
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
virtual
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
185af92b
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
0 → 100644
View file @
185af92b
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
185af92b
...
...
@@ -587,4 +587,52 @@ struct OffsettedBlockToCTileMap
index_t
block_start_
;
};
/**
* @brief Simple tile mapping which creates 3D grid of block of threads.
*
* @paragraph Description
* This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
* blocks. The first 2D are regular 2D tiles created by division of output GEMM
* dimenions by corresponding tile size. The third dimension (Z) is a k-split dimension,
* which denotes the number of blocks we use to divide work on GEMM K dimension onto.
*
* @tparam MPerBlock Output block tile size in M dimension.
* @tparam NPerBlock Output block tile size in N dimension.
*/
template
<
index_t
MPerBlock
,
index_t
NPerBlock
>
struct
BlockToCTileMap_3DGrid_KSplit
{
__host__
__device__
BlockToCTileMap_3DGrid_KSplit
()
=
default
;
__host__
__device__
constexpr
auto
CalculateGridSize
(
index_t
M
,
index_t
N
,
index_t
k_split
)
const
{
// Create 3D grid
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
return
std
::
make_tuple
(
N0
,
M0
,
k_split
);
}
template
<
typename
TopIdx
>
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
)
const
{
return
make_tuple
(
blockIdx
.
z
,
blockIdx
.
y
,
blockIdx
.
x
);
}
template
<
typename
CTileIdx
,
typename
CTileDim
>
__host__
__device__
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
const
CTileDim
&
/* c_tile_dim */
)
const
{
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
template
<
typename
CGridDesc_M_N
>
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
185af92b
...
...
@@ -505,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr
long_index_t
TwoGB
=
(
long_index_t
{
1
}
<<
31
);
if
(
!
(
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
()
*
sizeof
(
ADataType
)
<=
TwoGB
&&
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
()
*
sizeof
(
BDataType
)
<=
TwoGB
&&
e_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
EDataType
)
<=
TwoGB
))
{
return
false
;
}
return
true
;
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
185af92b
...
...
@@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if defined(__gfx90a__)
#if
CK_WORKAROUND_DENORM_FIX &&
defined(__gfx90a__)
using
ABDataTypeAdjusted
=
conditional_t
<
is_same_v
<
ABDataType
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
ABDataType
>
;
#else
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
185af92b
...
...
@@ -264,6 +264,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr
long_index_t
TwoGB
=
(
long_index_t
{
1
}
<<
31
);
if
(
!
(
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
()
*
sizeof
(
FloatA
)
<=
TwoGB
&&
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
()
*
sizeof
(
FloatB
)
<=
TwoGB
))
{
return
false
;
}
return
true
;
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
View file @
185af92b
...
...
@@ -265,7 +265,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if defined(__gfx90a__)
#if
CK_WORKAROUND_DENORM_FIX &&
defined(__gfx90a__)
using
FloatABAdjusted
=
conditional_t
<
is_same_v
<
FloatAB
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
FloatAB
>
;
#else
using
FloatABAdjusted
=
FloatAB
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
185af92b
...
...
@@ -135,7 +135,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if defined(__gfx90a__)
#if
CK_WORKAROUND_DENORM_FIX &&
defined(__gfx90a__)
using
FloatABAdjusted
=
conditional_t
<
is_same_v
<
FloatAB
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
FloatAB
>
;
#else
using
FloatABAdjusted
=
FloatAB
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
185af92b
This diff is collapsed.
Click to expand it.
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
View file @
185af92b
...
...
@@ -117,20 +117,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
Empty_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
...
...
@@ -159,20 +145,21 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances
(
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
Empty_Tuple
,
G
NHWK
,
int8_t
,
int8_t
,
NHW
G
K
,
BF16
,
BF16
,
Empty_Tuple
,
int8_t
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
NHWGC
,
...
...
@@ -187,6 +174,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
3
,
...
...
@@ -385,12 +386,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWGC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NHWGK
>
)
...
...
@@ -398,7 +393,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
float
>
)
{
// no instance
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
...
...
@@ -409,12 +404,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
// no instance
}
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
// no instance
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
GNDHWC
>
&&
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
View file @
185af92b
...
...
@@ -68,6 +68,58 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
...
...
@@ -109,11 +161,17 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp
View file @
185af92b
...
...
@@ -25,6 +25,10 @@ void add_device_normalization_rank_5_3_swish_f16_instances(
void
add_device_normalization_rank_5_3_swish_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F32
,
F32
,
F32
,
F32
,
F32
,
Swish
,
5
,
3
>>>&
);
// [x, gamma, beta, y] = [f16, f32, f32, f16]
void
add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F16
,
F32
,
F32
,
F32
,
F16
,
Swish
,
5
,
3
>>>&
);
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
...
...
@@ -70,6 +74,14 @@ struct DeviceOperationInstanceFactory<
add_device_normalization_rank_5_3_swish_f32_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
GammaDataType
,
F32
>
&&
is_same_v
<
BetaDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F16
>
)
{
if
constexpr
(
Rank
==
5
&&
NumReduceDim
==
3
)
{
add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp
View file @
185af92b
...
...
@@ -17,14 +17,14 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
// grouped conv2d forward,
G
NHWC/GKYXC/
G
NHWK
// grouped conv2d forward, NHW
G
C/GKYXC/NHW
G
K
void
add_device_conv2d_dl_bias_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
...
...
@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
void
add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
...
...
@@ -52,10 +52,10 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
void
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
...
...
@@ -68,10 +68,10 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
void
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
...
...
@@ -83,10 +83,10 @@ void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
void
add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
...
...
@@ -99,10 +99,10 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
void
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
...
...
@@ -154,9 +154,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
G
NHWC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHW
G
C
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_GK_Tuple
>
&&
is_same_v
<
OutLayout
,
G
NHWK
>
)
is_same_v
<
OutLayout
,
NHW
G
K
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
DsDataType
,
I32_F32_Tuple
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
...
...
@@ -220,9 +220,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
G
NHWC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHW
G
C
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_GK_Tuple
>
&&
is_same_v
<
OutLayout
,
G
NHWK
>
)
is_same_v
<
OutLayout
,
NHW
G
K
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
DsDataType
,
I32_F32_Tuple
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp
View file @
185af92b
...
...
@@ -17,14 +17,14 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
// grouped conv2d forward,
G
NHWC/GKYXC/
G
NHWK
// grouped conv2d forward, NHW
G
C/GKYXC/NHW
G
K
void
add_device_conv2d_dl_bias_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_Tuple
,
...
...
@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
void
add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_Tuple
,
...
...
@@ -51,10 +51,10 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
void
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_Tuple
,
...
...
@@ -67,10 +67,10 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
void
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_Tuple
,
...
...
@@ -82,10 +82,10 @@ void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
void
add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_Tuple
,
...
...
@@ -97,10 +97,10 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
void
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
I32_Tuple
,
...
...
@@ -152,9 +152,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
G
NHWC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHW
G
C
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_Tuple
>
&&
is_same_v
<
OutLayout
,
G
NHWK
>
)
is_same_v
<
OutLayout
,
NHW
G
K
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
DsDataType
,
I32_Tuple
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
...
...
@@ -218,9 +218,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
G
NHWC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHW
G
C
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_Tuple
>
&&
is_same_v
<
OutLayout
,
G
NHWK
>
)
is_same_v
<
OutLayout
,
NHW
G
K
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
DsDataType
,
I32_Tuple
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp
View file @
185af92b
...
...
@@ -17,13 +17,13 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
// grouped conv2d forward,
G
NHWC/GKYXC/
G
NHWK
// grouped conv2d forward, NHW
G
C/GKYXC/NHW
G
K
void
add_device_conv2d_dl_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
F32_Tuple
,
...
...
@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perchannel_quantization_int8_instances(
void
add_device_conv2d_dl_relu_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
F32_Tuple
,
...
...
@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
void
add_device_conv2d_xdl_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
F32_Tuple
,
...
...
@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
void
add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
GK_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
F32_Tuple
,
...
...
@@ -119,9 +119,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
G
NHWC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHW
G
C
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_Tuple
>
&&
is_same_v
<
OutLayout
,
G
NHWK
>
)
is_same_v
<
OutLayout
,
NHW
G
K
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp
View file @
185af92b
...
...
@@ -17,13 +17,13 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
// grouped conv2d forward,
G
NHWC/GKYXC/
G
NHWK
// grouped conv2d forward, NHW
G
C/GKYXC/NHW
G
K
void
add_device_conv2d_dl_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
Empty_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
Empty_Tuple
,
...
...
@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perlayer_quantization_int8_instances(
void
add_device_conv2d_dl_relu_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
Empty_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
Empty_Tuple
,
...
...
@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
void
add_device_conv2d_xdl_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
Empty_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
Empty_Tuple
,
...
...
@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
void
add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
G
NHWC
,
NHW
G
C
,
GKYXC
,
Empty_Tuple
,
G
NHWK
,
NHW
G
K
,
int8_t
,
int8_t
,
Empty_Tuple
,
...
...
@@ -117,8 +117,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
G
NHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
G
NHWK
>
)
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHW
G
C
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NHW
G
K
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment