Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3f299c33
Unverified
Commit
3f299c33
authored
Mar 30, 2023
by
Adam Osewski
Committed by
GitHub
Mar 30, 2023
Browse files
Merge branch 'develop' into aosewski/ggemm_dl_instances
parents
507d793a
091570f5
Changes
75
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
836 additions
and
148 deletions
+836
-148
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+27
-17
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+366
-118
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+1
-1
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+55
-1
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+18
-0
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+9
-0
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp
...uped_convolution_bias_forward_perchannel_quantization.hpp
+94
-0
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp
...rouped_convolution_bias_forward_perlayer_quantization.hpp
+92
-0
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
.../device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
+18
-9
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp
...pu/quantization/conv2d_fwd/conv2d_quantization_common.hpp
+9
-0
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp
..._conv2d_dl_bias_perchannel_quantization_int8_instance.cpp
+36
-0
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp
...ce_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp
+37
-0
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp
...conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp
+35
-0
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp
...e_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp
+37
-0
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+2
-2
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
3f299c33
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -58,16 +58,16 @@ __global__ void
...
@@ -58,16 +58,16 @@ __global__ void
c_element_op
,
c_element_op
,
block_2_ctile_map
);
block_2_ctile_map
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c_element_op
;
ignore
=
block_2_ctile_map
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
...
@@ -131,6 +131,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -131,6 +131,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if defined(__gfx90a__)
using
FloatABAdjusted
=
conditional_t
<
is_same_v
<
FloatAB
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
FloatAB
>
;
#else
using
FloatABAdjusted
=
FloatAB
;
#endif
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -281,7 +291,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -281,7 +291,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using
BlockwiseGemm
=
using
BlockwiseGemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAB
Adjusted
,
FloatAcc
,
FloatAcc
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_block_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
...
@@ -367,7 +377,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -367,7 +377,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
Adjusted
,
decltype
(
a_grid_desc_k0_m_k1
),
decltype
(
a_grid_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
...
@@ -398,7 +408,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -398,7 +408,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
Adjusted
,
decltype
(
b_grid_desc_k0_n_k1
),
decltype
(
b_grid_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
...
@@ -428,7 +438,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -428,7 +438,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// sanity check
// sanity check
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
FloatAB
,
FloatAB
Adjusted
,
FloatAcc
,
FloatAcc
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_block_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
...
@@ -446,10 +456,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -446,10 +456,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
static_cast
<
FloatAB
Adjusted
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size_aligned
,
static_cast
<
FloatAB
Adjusted
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
3f299c33
...
@@ -18,60 +18,23 @@
...
@@ -18,60 +18,23 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
bool
HasMainKBlockLoop
,
typename
FloatC
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
>
typename
AGridDesc_B_K0_M_K1
,
typename
BGridDesc_B_K0_N_K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CBlockClusterAdaptor
,
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_xdlops_v2r4r2
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_xdlops_v2r4r2_simplified
(
typename
GridwiseGemm
::
Argument
karg
)
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_B_K0_M_K1
a_b_k0_m_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
b_b_k0_n_k1_grid_desc
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr
index_t
shared_block_size
=
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
__shared__
uint8_t
p_shared
[
shared_size
];
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
karg
,
static_cast
<
void
*>
(
p_shared
));
p_b_grid
,
p_c_grid
,
static_cast
<
void
*>
(
p_shared_block
),
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
a_element_op
,
b_element_op
,
c_element_op
,
c_block_cluster_adaptor
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
karg
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
a_b_k0_m_k1_grid_desc
;
ignore
=
b_b_k0_n_k1_grid_desc
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c_block_cluster_adaptor
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
...
@@ -79,13 +42,13 @@ template <index_t BlockSize,
...
@@ -79,13 +42,13 @@ template <index_t BlockSize,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
ALayout
,
typename
AGridDesc_B_K0_M_K1
,
typename
BLayout
,
typename
BGridDesc_B_K0_N_K1
,
typename
CLayout
,
typename
CMNGridDesc
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
K0PerBlock
,
...
@@ -126,10 +89,238 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -126,10 +89,238 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
M01
=
1
;
static
constexpr
auto
N01
=
1
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
struct
Argument
:
public
ck
::
tensor_operation
::
device
::
BaseArgument
{
const
FloatAB
*
p_a_grid
;
const
FloatAB
*
p_b_grid
;
FloatC
*
p_c_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
index_t
KPadded
;
index_t
K0
;
index_t
k_batch
;
Argument
(
const
FloatAB
*
p_a_grid_
,
const
FloatAB
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
K0_
,
index_t
k_batch_
)
:
p_a_grid
(
p_a_grid_
),
p_b_grid
(
p_b_grid_
),
p_c_grid
(
p_c_grid_
),
M
(
M_
),
N
(
N_
),
K
(
K_
),
StrideA
(
StrideA_
),
StrideB
(
StrideB_
),
StrideC
(
StrideC_
),
MPadded
(
MPadded_
),
NPadded
(
NPadded_
),
KPadded
(
KPadded_
),
K0
(
K0_
),
k_batch
(
k_batch_
)
{
}
void
Print
()
const
{
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"K0:"
<<
K0
<<
", "
<<
"KB:"
<<
k_batch
<<
"}"
<<
std
::
endl
;
}
};
__host__
__device__
static
auto
CalculateGridSize
(
const
Argument
&
karg
)
{
return
std
::
make_tuple
(
math
::
integer_divide_ceil
(
karg
.
N
,
NPerBlock
),
math
::
integer_divide_ceil
(
karg
.
M
,
MPerBlock
),
karg
.
k_batch
);
}
// prefer this to be called on host
__host__
__device__
static
auto
CalculateMPadded
(
index_t
M
)
{
return
(
M
+
MPerBlock
-
1
)
/
MPerBlock
*
MPerBlock
;
}
__host__
__device__
static
auto
CalculateNPadded
(
index_t
N
)
{
return
(
N
+
NPerBlock
-
1
)
/
NPerBlock
*
NPerBlock
;
}
__host__
__device__
static
auto
CalculateK0
(
index_t
K
,
index_t
K_Batch
=
1
)
{
// k_batch * k0 * k0_per_block * k1
auto
K_t
=
K_Batch
*
K0PerBlock
*
K1
;
return
(
K
+
K_t
-
1
)
/
K_t
*
K0PerBlock
;
}
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
auto
K0
=
CalculateK0
(
K
,
K_Batch
);
return
K_Batch
*
K0
*
K1
;
}
__host__
__device__
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
StrideA
,
index_t
KBatch
,
index_t
K0
,
index_t
KPad
)
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
const
auto
a_grid_desc_m_kpad
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_KBatch_K0_N_K1
(
index_t
K
,
index_t
NPad
,
index_t
N
,
index_t
StrideB
,
index_t
KBatch
,
index_t
K0
,
index_t
KPad
)
{
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}();
const
auto
b_grid_desc_kpad_n
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_right_pad_transform
(
K
,
KPad
-
K
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
__host__
__device__
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
MPad
,
index_t
NPad
,
index_t
StrideC
)
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -178,45 +369,68 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -178,45 +369,68 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
c_block_size
*
sizeof
(
FloatC
));
c_block_size
*
sizeof
(
FloatC
));
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_B_K0_M_K1
&
a_b_k0_m_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
&
b_b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
"wrong! K1 need to be known at compile-time"
);
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
static_assert
((
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
)
&&
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
(
NPerBlock
%
(
NRepeat
*
NPerXDL
))
==
0
,
{
"Invalid tuning param!"
);
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
return
false
;
const
auto
M
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I2
);
}
const
auto
N
=
b_b_k0_n_k1_grid_desc
.
GetLength
(
I2
);
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
const
auto
KBatch
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I0
);
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
return
false
;
}
if
(
!
(
M
==
c_m_n_grid_desc
.
GetLength
(
I0
)
&&
N
==
c_m_n_grid_desc
.
GetLength
(
I1
)
&&
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
K0
==
b_b_k0_n_k1_grid_desc
.
GetLength
(
I1
)
&&
{
K1
==
a_b_k0_m_k1_grid_desc
.
GetLength
(
I3
)
&&
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
K1
==
b_b_k0_n_k1_grid_desc
.
GetLength
(
I3
)
&&
return
false
;
KBatch
==
b_b_k0_n_k1_grid_desc
.
GetLength
(
I0
)))
}
return
false
;
else
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
return
false
;
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
return
false
;
}
else
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
return
false
;
}
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_m_n_grid_desc
)
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
{
return
false
;
if
(
karg
.
N
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
return
false
;
}
else
{
if
(
karg
.
M
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
return
false
;
}
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
}
}
__host__
__device__
static
auto
GetKPad
(
index_t
K
,
index_t
KBatch
)
{
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
KPad
=
KBatch
*
K0
*
K1
;
return
KPad
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
{
const
bool
has_main_k0_block_loop
=
K0
>
K0PerBlock
;
const
bool
has_main_k0_block_loop
=
K0
>
K0PerBlock
;
...
@@ -224,8 +438,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -224,8 +438,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return
has_main_k0_block_loop
;
return
has_main_k0_block_loop
;
}
}
template
<
typename
CGridDesc
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
const
C
MN
GridDesc
&
c_m_n_grid_desc
)
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc
&
c_m_n_grid_desc
)
{
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
...
@@ -242,10 +457,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -242,10 +457,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}
// return block_id to C matrix tile idx (m0, n0) mapping
// return block_id to C matrix tile idx (m0, n0) mapping
template
<
typename
CGridDesc
>
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
const
C
MN
GridDesc
&
c_m_n_grid_desc
,
index_t
/* M01 */
,
index_t
/* N01 */
,
index_t
KBatch
)
const
CGridDesc
&
c_m_n_grid_desc
,
index_t
/* M01 */
,
index_t
/* N01 */
,
index_t
KBatch
)
{
{
return
BlockToCTileMap_KSplit_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
C
MN
GridDesc
>
(
return
BlockToCTileMap_KSplit_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc
>
(
c_m_n_grid_desc
,
8
,
KBatch
);
c_m_n_grid_desc
,
8
,
KBatch
);
}
}
...
@@ -262,24 +478,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -262,24 +478,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
{}));
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
{}));
}
}
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
>
decltype
(
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CMNGridDesc
{}));
__device__
static
void
Run
(
const
Argument
&
karg
,
void
*
__restrict__
p_shared_block
)
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{},
1
,
1
,
1
));
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared_block
,
const
AGridDesc_B_K0_M_K1
&
a_b_k0_m_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
&
b_b_k0_n_k1_grid_desc
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
CBlockClusterAdaptor
&
c_block_cluster_adaptor
)
{
{
const
FloatAB
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatAB
*
p_b_grid
=
karg
.
p_b_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
const
auto
a_b_k0_m_k1_grid_desc
=
MakeAGridDescriptor_KBatch_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
,
karg
.
KPadded
);
const
auto
b_b_k0_n_k1_grid_desc
=
MakeBGridDescriptor_KBatch_K0_N_K1
(
karg
.
K
,
karg
.
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
,
karg
.
KPadded
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
MPadded
,
karg
.
NPadded
,
karg
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
const
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{};
const
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{};
const
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{};
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -289,26 +506,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -289,26 +506,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
// divide block work by [M, N]
const
index_t
block_m_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
);
const
auto
block_work_idx
=
const
index_t
block_n_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
);
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
const
index_t
k_batch_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
if
(
!
c_block_cluster_adaptor
.
ValidCTileIndex
(
make_tuple
(
block_work_idx
[
I1
],
block_work_idx
[
I2
]),
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
const
index_t
k_batch_id
=
block_work_idx
[
I0
];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_
work_idx
[
I1
]
*
MPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_
m_id
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_
work_idx
[
I2
]
*
NPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_
n_id
*
NPerBlock
);
// lds max alignment
// lds max alignment
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -444,7 +651,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -444,7 +651,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
auto
blockwise_gemm
=
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAB
,
...
@@ -647,7 +853,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -647,7 +853,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
c_block_desc_mblock_mperblock_nblock_nperblock
,
{
c_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_
work_idx
[
I1
],
0
,
block_work_idx
[
I2
]
,
0
),
make_multi_index
(
block_
m_id
,
0
,
block_n_id
,
0
),
c_element_op
};
c_element_op
};
constexpr
auto
mxdlperwave_forward_step
=
constexpr
auto
mxdlperwave_forward_step
=
...
@@ -716,6 +922,48 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -716,6 +922,48 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
});
});
}
}
}
}
template
<
typename
Layout
>
struct
LStr
{
static
std
::
string
Get
()
{
return
""
;
}
};
template
<
>
struct
LStr
<
ck
::
tensor_layout
::
gemm
::
RowMajor
>
{
static
std
::
string
Get
()
{
return
"R"
;
}
};
template
<
>
struct
LStr
<
ck
::
tensor_layout
::
gemm
::
ColumnMajor
>
{
static
std
::
string
Get
()
{
return
"C"
;
}
};
static
std
::
string
GetTypeString
()
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"GemmXdlSplitKCShuffle_"
<<
getGemmSpecializationString
(
GemmSpec
)
<<
"_"
<<
std
::
string
(
ALayout
::
name
)[
0
]
<<
std
::
string
(
BLayout
::
name
)[
0
]
<<
std
::
string
(
CLayout
::
name
)[
0
]
<<
"_"
<<
"B"
<<
BlockSize
<<
"_"
<<
"Vec"
<<
ABlockTransferSrcScalarPerVector
<<
"x"
<<
BBlockTransferSrcScalarPerVector
<<
"x"
<<
CBlockTransferScalarPerVector_NWaveNPerXDL
<<
"_"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
K0PerBlock
<<
"x"
<<
K1
;
// clang-format on
return
str
.
str
();
}
};
};
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
3f299c33
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/utility/data_type.hpp
View file @
3f299c33
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -1008,6 +1008,60 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
...
@@ -1008,6 +1008,60 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
return
uint16_t
(
u
.
int32
>>
16
);
return
uint16_t
(
u
.
int32
>>
16
);
}
}
// convert bfp16 to fp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
half_t
type_convert
<
half_t
,
bhalf_t
>
(
bhalf_t
x
)
{
float
x_fp32
=
type_convert
<
float
>
(
x
);
return
static_cast
<
half_t
>
(
x_fp32
);
}
// convert fp16 to bfp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
half_t
>
(
half_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
// convert bfp16 to int32 via fp32
template
<
>
inline
__host__
__device__
constexpr
int32_t
type_convert
<
int32_t
,
bhalf_t
>
(
bhalf_t
x
)
{
float
x_fp32
=
type_convert
<
float
>
(
x
);
return
static_cast
<
int32_t
>
(
x_fp32
);
}
// convert int32 to bfp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
int32_t
>
(
int32_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
// convert bfp16 to int8 via fp32
template
<
>
inline
__host__
__device__
constexpr
int8_t
type_convert
<
int8_t
,
bhalf_t
>
(
bhalf_t
x
)
{
float
x_fp32
=
type_convert
<
float
>
(
x
);
return
static_cast
<
int8_t
>
(
x_fp32
);
}
// convert int8 to bfp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
int8_t
>
(
int8_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
{
{
...
...
include/ck/utility/math_v2.hpp
View file @
3f299c33
...
@@ -92,6 +92,15 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
...
@@ -92,6 +92,15 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static
inline
__host__
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
static
inline
__host__
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
static
inline
__host__
half_t
tanh
(
half_t
x
)
{
return
static_cast
<
half_t
>
(
std
::
tanh
(
static_cast
<
float
>
(
x
)));
};
static
inline
__host__
float
tanh
(
float
x
)
{
return
std
::
tanh
(
x
);
};
static
inline
__host__
double
tanh
(
double
x
)
{
return
std
::
tanh
(
x
);
};
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static
inline
__device__
float
abs
(
float
x
)
{
return
::
abs
(
x
);
};
static
inline
__device__
float
abs
(
float
x
)
{
return
::
abs
(
x
);
};
...
@@ -172,5 +181,14 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
...
@@ -172,5 +181,14 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static
inline
__device__
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
static
inline
__device__
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
static
inline
__device__
half_t
tanh
(
half_t
x
)
{
return
static_cast
<
half_t
>
(
::
tanhf
(
static_cast
<
float
>
(
x
)));
};
static
inline
__device__
float
tanh
(
float
x
)
{
return
::
tanhf
(
x
);
};
static
inline
__device__
double
tanh
(
double
x
)
{
return
::
tanh
(
x
);
};
}
// namespace math
}
// namespace math
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
3f299c33
...
@@ -85,6 +85,7 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>;
...
@@ -85,6 +85,7 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor
// pointwise functor
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Relu
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
using
Relu
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
using
TanH
=
ck
::
tensor_operation
::
element_wise
::
TanH
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
Bilinear
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
using
Bilinear
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
using
AddAddFastGelu
=
ck
::
tensor_operation
::
element_wise
::
AddAddFastGelu
;
using
AddAddFastGelu
=
ck
::
tensor_operation
::
element_wise
::
AddAddFastGelu
;
...
@@ -102,6 +103,10 @@ template <typename Activation>
...
@@ -102,6 +103,10 @@ template <typename Activation>
using
Add_Activation_Mul_Clamp
=
using
Add_Activation_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul_Clamp
<
Activation
>
;
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul_Clamp
<
Activation
>
;
template
<
typename
Activation
>
using
Add_Mul_Activation_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Mul_Activation_Mul_Clamp
<
Activation
>
;
template
<
typename
Activation
>
template
<
typename
Activation
>
using
Activation_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul2_Clamp
<
Activation
>
;
using
Activation_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul2_Clamp
<
Activation
>
;
...
@@ -109,6 +114,10 @@ template <typename Activation>
...
@@ -109,6 +114,10 @@ template <typename Activation>
using
Add_Activation_Mul2_Clamp
=
using
Add_Activation_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul2_Clamp
<
Activation
>
;
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul2_Clamp
<
Activation
>
;
template
<
typename
Activation
>
using
Add_Mul2_Activation_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Mul2_Activation_Mul_Clamp
<
Activation
>
;
template
<
typename
DeviceOp
,
typename
Tag
=
void
>
template
<
typename
DeviceOp
,
typename
Tag
=
void
>
struct
DeviceOperationInstanceFactory
;
struct
DeviceOperationInstanceFactory
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp
View file @
3f299c33
...
@@ -49,6 +49,22 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
...
@@ -49,6 +49,22 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp
<
Relu
>>>>&
Add_Activation_Mul2_Clamp
<
Relu
>>>>&
instances
);
instances
);
void
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Mul2_Activation_Mul_Clamp
<
TanH
>>>>&
instances
);
void
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances
(
void
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
...
@@ -80,6 +96,23 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
...
@@ -80,6 +96,23 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp
<
Relu
>>>>&
Add_Activation_Mul2_Clamp
<
Relu
>>>>&
instances
);
instances
);
void
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Mul2_Activation_Mul_Clamp
<
TanH
>>>>&
instances
);
// piecewise activation function
template
<
ck
::
index_t
NumDimSpatial
,
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
InLayout
,
typename
WeiLayout
,
typename
WeiLayout
,
...
@@ -145,6 +178,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -145,6 +178,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
};
};
// non-piecewise activation function
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
DsLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
DsDataType
,
typename
OutDataType
,
typename
Activation
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
DsDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Add_Mul2_Activation_Mul_Clamp
<
Activation
>>>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
DsDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Add_Mul2_Activation_Mul_Clamp
<
Activation
>>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_GK_Tuple
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
DsDataType
,
I32_F32_Tuple
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
if
constexpr
(
is_same_v
<
Activation
,
TanH
>
)
{
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances
(
op_ptrs
);
}
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp
View file @
3f299c33
...
@@ -49,6 +49,21 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
...
@@ -49,6 +49,21 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
Add_Activation_Mul_Clamp
<
Relu
>>>>&
Add_Activation_Mul_Clamp
<
Relu
>>>>&
instances
);
instances
);
void
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Mul_Activation_Mul_Clamp
<
TanH
>>>>&
instances
);
void
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances
(
void
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
...
@@ -80,6 +95,22 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
...
@@ -80,6 +95,22 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
Add_Activation_Mul_Clamp
<
Relu
>>>>&
Add_Activation_Mul_Clamp
<
Relu
>>>>&
instances
);
instances
);
void
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Mul_Activation_Mul_Clamp
<
TanH
>>>>&
instances
);
// piecewise activation function
template
<
ck
::
index_t
NumDimSpatial
,
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
InLayout
,
typename
WeiLayout
,
typename
WeiLayout
,
...
@@ -145,6 +176,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -145,6 +176,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
};
};
// non-piecewise activation function
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
DsLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
DsDataType
,
typename
OutDataType
,
typename
Activation
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
DsDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Add_Mul_Activation_Mul_Clamp
<
Activation
>>>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
DsDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Add_Mul_Activation_Mul_Clamp
<
Activation
>>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_Tuple
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
DsDataType
,
I32_Tuple
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
if
constexpr
(
is_same_v
<
Activation
,
TanH
>
)
{
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances
(
op_ptrs
);
}
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
3f299c33
...
@@ -26,7 +26,8 @@ using S = ck::Sequence<Is...>;
...
@@ -26,7 +26,8 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmMNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
=
std
::
tuple
<
using
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
=
std
::
tuple
<
...
@@ -35,14 +36,22 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
...
@@ -35,14 +36,22 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
64
,
192
,
4
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
48
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
192
,
64
,
4
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
192
,
4
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
192
,
32
,
4
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
64
,
32
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
// clang-format on
// clang-format on
>
;
>
;
...
...
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp
View file @
3f299c33
...
@@ -25,6 +25,7 @@ using GNHWK = ck::tensor_layout::convolution::GNHWK;
...
@@ -25,6 +25,7 @@ using GNHWK = ck::tensor_layout::convolution::GNHWK;
using
GK
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
GK
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Relu
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
using
Relu
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
using
TanH
=
ck
::
tensor_operation
::
element_wise
::
TanH
;
using
GK_Tuple
=
ck
::
Tuple
<
GK
>
;
using
GK_Tuple
=
ck
::
Tuple
<
GK
>
;
using
GK_GK_Tuple
=
ck
::
Tuple
<
GK
,
GK
>
;
using
GK_GK_Tuple
=
ck
::
Tuple
<
GK
,
GK
>
;
...
@@ -32,17 +33,25 @@ using I32_Tuple = ck::Tuple<int32_t>;
...
@@ -32,17 +33,25 @@ using I32_Tuple = ck::Tuple<int32_t>;
using
F32_Tuple
=
ck
::
Tuple
<
float
>
;
using
F32_Tuple
=
ck
::
Tuple
<
float
>
;
using
I32_F32_Tuple
=
ck
::
Tuple
<
int32_t
,
float
>
;
using
I32_F32_Tuple
=
ck
::
Tuple
<
int32_t
,
float
>
;
// perlayer
using
Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul_Clamp
<
PassThrough
>
;
using
Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul_Clamp
<
PassThrough
>
;
using
Relu_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul_Clamp
<
Relu
>
;
using
Relu_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul_Clamp
<
Relu
>
;
// bias + perlayer
using
Add_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul_Clamp
<
PassThrough
>
;
using
Add_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul_Clamp
<
PassThrough
>
;
using
Add_Relu_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul_Clamp
<
Relu
>
;
using
Add_Relu_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul_Clamp
<
Relu
>
;
using
Add_Mul_TanH_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Mul_Activation_Mul_Clamp
<
TanH
>
;
// perchannel
using
Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul2_Clamp
<
PassThrough
>
;
using
Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul2_Clamp
<
PassThrough
>
;
using
Relu_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul2_Clamp
<
Relu
>
;
using
Relu_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul2_Clamp
<
Relu
>
;
// bias + perchannel
using
Add_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul2_Clamp
<
PassThrough
>
;
using
Add_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul2_Clamp
<
PassThrough
>
;
using
Add_Relu_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul2_Clamp
<
Relu
>
;
using
Add_Relu_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul2_Clamp
<
Relu
>
;
using
Add_Mul2_TanH_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Mul2_Activation_Mul_Clamp
<
TanH
>
;
static
constexpr
ck
::
index_t
NDimSpatial
=
2
;
static
constexpr
ck
::
index_t
NDimSpatial
=
2
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
...
...
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp
View file @
3f299c33
...
@@ -76,6 +76,42 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
...
@@ -76,6 +76,42 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
ConvFwd1x1S1P0
,
ConvFwd1x1S1P0
,
4
>
{});
4
>
{});
}
}
void
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
NDimSpatial
,
GNHWC
,
GKYXC
,
GK_GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Mul2_TanH_Mul_Clamp
>>>&
instances
)
{
// dl
add_device_operation_instances
(
instances
,
device_grouped_conv2d_dl_int8_instances
<
GK_GK_Tuple
,
I32_F32_Tuple
,
Add_Mul2_TanH_Mul_Clamp
,
ConvFwdDefault
,
4
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv2d_dl_int8_instances
<
GK_GK_Tuple
,
I32_F32_Tuple
,
Add_Mul2_TanH_Mul_Clamp
,
ConvFwd1x1P0
,
4
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv2d_dl_int8_instances
<
GK_GK_Tuple
,
I32_F32_Tuple
,
Add_Mul2_TanH_Mul_Clamp
,
ConvFwd1x1S1P0
,
4
>
{});
}
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp
View file @
3f299c33
...
@@ -76,6 +76,43 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
...
@@ -76,6 +76,43 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
ConvFwd1x1S1P0
,
ConvFwd1x1S1P0
,
4
>
{});
4
>
{});
}
}
void
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
NDimSpatial
,
GNHWC
,
GKYXC
,
GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Mul_TanH_Mul_Clamp
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv2d_dl_int8_instances
<
GK_Tuple
,
I32_Tuple
,
Add_Mul_TanH_Mul_Clamp
,
ConvFwdDefault
,
4
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv2d_dl_int8_instances
<
GK_Tuple
,
I32_Tuple
,
Add_Mul_TanH_Mul_Clamp
,
ConvFwd1x1P0
,
4
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv2d_dl_int8_instances
<
GK_Tuple
,
I32_Tuple
,
Add_Mul_TanH_Mul_Clamp
,
ConvFwd1x1S1P0
,
4
>
{});
}
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp
View file @
3f299c33
...
@@ -74,6 +74,41 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
...
@@ -74,6 +74,41 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
ConvFwd1x1S1P0
,
ConvFwd1x1S1P0
,
8
>
{});
8
>
{});
}
}
void
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
NDimSpatial
,
GNHWC
,
GKYXC
,
GK_GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Mul2_TanH_Mul_Clamp
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv2d_xdl_int8_instances
<
GK_GK_Tuple
,
I32_F32_Tuple
,
Add_Mul2_TanH_Mul_Clamp
,
ConvFwdDefault
,
8
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv2d_xdl_int8_instances
<
GK_GK_Tuple
,
I32_F32_Tuple
,
Add_Mul2_TanH_Mul_Clamp
,
ConvFwd1x1P0
,
8
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv2d_xdl_int8_instances
<
GK_GK_Tuple
,
I32_F32_Tuple
,
Add_Mul2_TanH_Mul_Clamp
,
ConvFwd1x1S1P0
,
8
>
{});
}
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp
View file @
3f299c33
...
@@ -76,6 +76,43 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
...
@@ -76,6 +76,43 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
ConvFwd1x1S1P0
,
ConvFwd1x1S1P0
,
8
>
{});
8
>
{});
}
}
void
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
NDimSpatial
,
GNHWC
,
GKYXC
,
GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Mul_TanH_Mul_Clamp
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv2d_xdl_int8_instances
<
GK_Tuple
,
I32_Tuple
,
Add_Mul_TanH_Mul_Clamp
,
ConvFwdDefault
,
8
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv2d_xdl_int8_instances
<
GK_Tuple
,
I32_Tuple
,
Add_Mul_TanH_Mul_Clamp
,
ConvFwd1x1P0
,
8
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv2d_xdl_int8_instances
<
GK_Tuple
,
I32_Tuple
,
Add_Mul_TanH_Mul_Clamp
,
ConvFwd1x1S1P0
,
8
>
{});
}
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
script/cmake-ck-dev.sh
View file @
3f299c33
...
@@ -8,8 +8,8 @@ MY_PROJECT_SOURCE=$1
...
@@ -8,8 +8,8 @@ MY_PROJECT_SOURCE=$1
cmake
\
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE
-mcumode
\
-D
CMAKE_CXX_FLAGS
=
"-std=c++17 -O3 -ftemplate-backtrace-limit=0
-fPIE
-Wno-gnu-line-marker
\
-
mno-wavefrontsize64 -Wno-gnu-line-marker -save-temps=
$PWD
"
\
-
save-temps=
$PWD
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
"gfx908;gfx90a"
\
-D
GPU_TARGETS
=
"gfx908;gfx90a"
\
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment