Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9b1437db
Unverified
Commit
9b1437db
authored
Sep 15, 2023
by
Bartlomiej Wroblewski
Committed by
GitHub
Sep 15, 2023
Browse files
Merge branch 'develop' into bwroblew/dl_fails_vec_size
parents
27a59270
f9d0eddb
Changes
102
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1329 additions
and
182 deletions
+1329
-182
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
...gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
+18
-28
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+4
-5
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+223
-7
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+9
-1
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+16
-2
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+2
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+123
-8
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+106
-111
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+12
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+94
-10
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+4
-3
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+6
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
...include/ck/library/tensor_operation_instance/gpu/gemm.hpp
+28
-4
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp
...brary/tensor_operation_instance/gpu/gemm_multiply_add.hpp
+4
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
.../ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
+4
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
...bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
+86
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp
.../grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...tion_instance/gpu/grouped_convolution_backward_weight.hpp
+304
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
...y/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
+190
-0
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+94
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
View file @
9b1437db
...
...
@@ -29,7 +29,9 @@ namespace ck {
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
ABDataType
,
// FIXME: don't assume A/B have same datatype
template
<
typename
ADataType
,
typename
BDataType
,
typename
ComputeType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
...
...
@@ -96,17 +98,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma is fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if CK_WORKAROUND_DENORM_FIX
using
ABDataTypeAdjusted
=
conditional_t
<
is_same_v
<
ABDataType
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
ABDataType
>
;
#else
using
ABDataTypeAdjusted
=
ABDataType
;
#endif
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
...
...
@@ -196,7 +187,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
ABData
Type
),
sizeof
(
Compute
Type
),
c_block_size
*
sizeof
(
CShuffleDataType
));
}
...
...
@@ -401,8 +392,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// check tensor size: cannot be larger than 2GB each
constexpr
long_index_t
TwoGB
=
(
long_index_t
{
1
}
<<
31
);
if
(
!
(
a_grid_desc_kbatch_ak0_m_ak1
.
GetElementSpaceSize
()
*
sizeof
(
A
B
DataType
)
<=
TwoGB
&&
b_grid_desc_kbatch_bk0_n_bk1
.
GetElementSpaceSize
()
*
sizeof
(
A
BDataType
)
<=
TwoGB
&&
if
(
!
(
a_grid_desc_kbatch_ak0_m_ak1
.
GetElementSpaceSize
()
*
sizeof
(
ADataType
)
<=
TwoGB
&&
b_grid_desc_kbatch_bk0_n_bk1
.
GetElementSpaceSize
()
*
sizeof
(
BDataType
)
<=
TwoGB
&&
e_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
EDataType
)
<=
TwoGB
))
{
return
false
;
...
...
@@ -470,8 +461,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CDEElementwiseOperation_
,
typename
Block2ETileMap
>
__device__
static
void
Run
(
const
A
B
DataType
*
__restrict__
p_a_grid
,
const
A
BDataType
*
__restrict__
p_b_grid
,
__device__
static
void
Run
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
void
*
__restrict__
p_shared
,
...
...
@@ -538,8 +529,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
Sequence
<
1
,
AK0PerBlock
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
A
B
DataType
,
ABDataTypeAdjusted
,
ADataType
,
ComputeType
,
decltype
(
a_grid_desc_kbatch_ak0_m_ak1
),
decltype
(
a_block_desc_kbatch_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
...
...
@@ -569,8 +560,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
Sequence
<
1
,
BK0PerBlock
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
A
BDataType
,
ABDataTypeAdjusted
,
BDataType
,
ComputeType
,
decltype
(
b_grid_desc_kbatch_bk0_n_bk1
),
decltype
(
b_block_desc_kbatch_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
...
...
@@ -606,11 +597,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
ABDataTypeAdjusted
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
ComputeType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
ABDataTypeAdjusted
,
ComputeType
,
AccDataType
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
@@ -683,11 +674,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ABDataTypeAdjusted
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
static_cast
<
ComputeType
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ABDataTypeAdjusted
*>
(
p_shared
)
+
a_block_space_size_aligned
,
static_cast
<
ComputeType
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
KPerBlock
/
AK1
,
0
,
0
);
...
...
@@ -999,8 +989,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
const
index_t
KBatch
,
const
Block2ETileMap
&
block_2_etile_map
)
{
const
auto
p_a_grid
=
reinterpret_cast
<
const
A
B
DataType
*>
(
p_a_grid_
);
const
auto
p_b_grid
=
reinterpret_cast
<
const
A
BDataType
*>
(
p_b_grid_
);
const
auto
p_a_grid
=
reinterpret_cast
<
const
ADataType
*>
(
p_a_grid_
);
const
auto
p_b_grid
=
reinterpret_cast
<
const
BDataType
*>
(
p_b_grid_
);
const
auto
p_e_grid
=
reinterpret_cast
<
EDataType
*>
(
p_e_grid_
);
using
DsGridDesc_M_N
=
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
9b1437db
...
...
@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
);
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
v
;
});
const
bool
is_dst_valid
=
...
...
@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert
<
DstData
>
(
v
)
;
dst_buf
(
Number
<
dst_offset
>
{})
=
v
;
});
});
}
...
...
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
View file @
9b1437db
...
...
@@ -11,8 +11,14 @@ namespace ck {
enum
struct
DppInstr
{
dpp8_f16_16x16x2
=
0
,
dpp8_f16_1x32x2
=
0
,
dpp8_f16_2x16x2
,
dpp8_f16_2x32x2
,
dpp8_f16_4x16x2
,
dpp8_f16_4x32x2
,
dpp8_f16_8x16x2
,
dpp8_f16_8x32x2
,
dpp8_f16_16x16x2
,
dpp8_f16_32x8x2
};
...
...
@@ -101,6 +107,36 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
}
};
// Compile-time traits for DppInstr::dpp8_f16_8x16x2: an 8 (M) x 16 (N) tile per
// wave with k_per_dpp = 2 and half_t as the base element type.
//
// Tile arithmetic visible from the constants below:
//   (m_per_wave / m_per_lanegroup) * (n_per_wave / n_per_lanegroup) = 2 * 2 = 4
//   wave_size / lanegroup_size                                      = 32 / 8 = 4
template <>
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
{
    using BaseType = half_t;

    // Wave / lanegroup geometry.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;

    // Tile extents along M.
    static constexpr index_t m_per_wave      = 8;
    static constexpr index_t m_per_lanegroup = 4;
    static constexpr index_t m_per_thread    = 4;

    // Tile extents along N.
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t n_per_thread    = 1;

    // Reduction depth handled by one invocation.
    static constexpr index_t k_per_dpp = 2;

    // Forwarded verbatim to dpp8::DppLanegroupGemm as its last template argument.
    static constexpr bool share_a = true;

    // Runs one lanegroup GEMM step on (a, b), updating reg_c.
    // MPerDpp and NPerDpp are part of the call interface but are not used in the body.
    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
template
<
>
struct
dpp_type
<
DppInstr
::
dpp8_f16_16x16x2
>
{
...
...
@@ -131,6 +167,156 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
}
};
// Compile-time traits for DppInstr::dpp8_f16_4x32x2: a 4 (M) x 32 (N) tile per
// wave with k_per_dpp = 2 and half_t as the base element type.
//
// Tile arithmetic visible from the constants below:
//   (m_per_wave / m_per_lanegroup) * (n_per_wave / n_per_lanegroup) = 1 * 4 = 4
//   wave_size / lanegroup_size                                      = 32 / 8 = 4
template <>
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
{
    using BaseType = half_t;

    // Wave / lanegroup geometry.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;

    // Tile extents along M.
    static constexpr index_t m_per_wave      = 4;
    static constexpr index_t m_per_lanegroup = 4;
    static constexpr index_t m_per_thread    = 4;

    // Tile extents along N.
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t n_per_thread    = 1;

    // Reduction depth handled by one invocation.
    static constexpr index_t k_per_dpp = 2;

    // Forwarded verbatim to dpp8::DppLanegroupGemm as its last template argument.
    static constexpr bool share_a = true;

    // Runs one lanegroup GEMM step on (a, b), updating reg_c.
    // MPerDpp and NPerDpp are part of the call interface but are not used in the body.
    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Compile-time traits for DppInstr::dpp8_f16_4x16x2: a 4 (M) x 16 (N) tile per
// wave with k_per_dpp = 2 and half_t as the base element type.
//
// Tile arithmetic visible from the constants below:
//   (m_per_wave / m_per_lanegroup) * (n_per_wave / n_per_lanegroup) = 2 * 2 = 4
//   wave_size / lanegroup_size                                      = 32 / 8 = 4
template <>
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
{
    using BaseType = half_t;

    // Wave / lanegroup geometry.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;

    // Tile extents along M.
    static constexpr index_t m_per_wave      = 4;
    static constexpr index_t m_per_lanegroup = 2;
    static constexpr index_t m_per_thread    = 2;

    // Tile extents along N.
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t n_per_thread    = 1;

    // Reduction depth handled by one invocation.
    static constexpr index_t k_per_dpp = 2;

    // Forwarded verbatim to dpp8::DppLanegroupGemm as its last template argument.
    static constexpr bool share_a = true;

    // Runs one lanegroup GEMM step on (a, b), updating reg_c.
    // MPerDpp and NPerDpp are part of the call interface but are not used in the body.
    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Compile-time traits for DppInstr::dpp8_f16_1x32x2: a 1 (M) x 32 (N) tile per
// wave with k_per_dpp = 2 and half_t as the base element type.
//
// Tile arithmetic visible from the constants below:
//   (m_per_wave / m_per_lanegroup) * (n_per_wave / n_per_lanegroup) = 1 * 4 = 4
//   wave_size / lanegroup_size                                      = 32 / 8 = 4
template <>
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
{
    using BaseType = half_t;

    // Wave / lanegroup geometry.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;

    // Tile extents along M.
    static constexpr index_t m_per_wave      = 1;
    static constexpr index_t m_per_lanegroup = 1;
    static constexpr index_t m_per_thread    = 1;

    // Tile extents along N.
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t n_per_thread    = 1;

    // Reduction depth handled by one invocation.
    static constexpr index_t k_per_dpp = 2;

    // Forwarded verbatim to dpp8::DppLanegroupGemm as its last template argument.
    static constexpr bool share_a = true;

    // Runs one lanegroup GEMM step on (a, b), updating reg_c.
    // MPerDpp and NPerDpp are part of the call interface but are not used in the body.
    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Compile-time traits for DppInstr::dpp8_f16_2x32x2: a 2 (M) x 32 (N) tile per
// wave with k_per_dpp = 2 and half_t as the base element type.
//
// Tile arithmetic visible from the constants below:
//   (m_per_wave / m_per_lanegroup) * (n_per_wave / n_per_lanegroup) = 1 * 4 = 4
//   wave_size / lanegroup_size                                      = 32 / 8 = 4
template <>
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
{
    using BaseType = half_t;

    // Wave / lanegroup geometry.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;

    // Tile extents along M.
    static constexpr index_t m_per_wave      = 2;
    static constexpr index_t m_per_lanegroup = 2;
    static constexpr index_t m_per_thread    = 2;

    // Tile extents along N.
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t n_per_thread    = 1;

    // Reduction depth handled by one invocation.
    static constexpr index_t k_per_dpp = 2;

    // Forwarded verbatim to dpp8::DppLanegroupGemm as its last template argument.
    static constexpr bool share_a = true;

    // Runs one lanegroup GEMM step on (a, b), updating reg_c.
    // MPerDpp and NPerDpp are part of the call interface but are not used in the body.
    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
// Compile-time traits for DppInstr::dpp8_f16_2x16x2: a 2 (M) x 16 (N) tile per
// wave with k_per_dpp = 2 and half_t as the base element type.
//
// Tile arithmetic visible from the constants below:
//   (m_per_wave / m_per_lanegroup) * (n_per_wave / n_per_lanegroup) = 2 * 2 = 4
//   wave_size / lanegroup_size                                      = 32 / 8 = 4
template <>
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
{
    using BaseType = half_t;

    // Wave / lanegroup geometry.
    static constexpr index_t wave_size      = 32;
    static constexpr index_t lanegroup_size = 8;

    // Tile extents along M.
    static constexpr index_t m_per_wave      = 2;
    static constexpr index_t m_per_lanegroup = 1;
    static constexpr index_t m_per_thread    = 1;

    // Tile extents along N.
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t n_per_thread    = 1;

    // Reduction depth handled by one invocation.
    static constexpr index_t k_per_dpp = 2;

    // Forwarded verbatim to dpp8::DppLanegroupGemm as its last template argument.
    static constexpr bool share_a = true;

    // Runs one lanegroup GEMM step on (a, b), updating reg_c.
    // MPerDpp and NPerDpp are part of the call interface but are not used in the body.
    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        using LanegroupGemm = dpp8::DppLanegroupGemm<m_per_thread,
                                                     n_per_thread,
                                                     k_per_dpp,
                                                     BaseType,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     share_a>;
        LanegroupGemm{}.Run(a, b, reg_c);
    }
};
template
<
typename
BaseType
,
index_t
MPerDpp
,
index_t
NPerDpp
>
struct
DppSelector
{
...
...
@@ -143,6 +329,12 @@ struct DppSelector
return
DppInstr
::
dpp8_f16_8x32x2
;
}
// GetDpp specialization for <half_t, 8, 16>: selects DppInstr::dpp8_f16_8x16x2.
// The three template arguments correspond to DppSelector's
// <BaseType, MPerDpp, NPerDpp>, which is how selected_dpp invokes GetDpp.
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
{
return
DppInstr
::
dpp8_f16_8x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
{
...
...
@@ -155,6 +347,36 @@ struct DppSelector
return
DppInstr
::
dpp8_f16_32x8x2
;
}
// GetDpp specialization for <half_t, 1, 32>: selects DppInstr::dpp8_f16_1x32x2.
// The three template arguments correspond to DppSelector's
// <BaseType, MPerDpp, NPerDpp>, which is how selected_dpp invokes GetDpp.
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
{
return
DppInstr
::
dpp8_f16_1x32x2
;
}
// GetDpp specialization for <half_t, 2, 32>: selects DppInstr::dpp8_f16_2x32x2.
// The three template arguments correspond to DppSelector's
// <BaseType, MPerDpp, NPerDpp>, which is how selected_dpp invokes GetDpp.
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
{
return
DppInstr
::
dpp8_f16_2x32x2
;
}
// GetDpp specialization for <half_t, 2, 16>: selects DppInstr::dpp8_f16_2x16x2.
// The three template arguments correspond to DppSelector's
// <BaseType, MPerDpp, NPerDpp>, which is how selected_dpp invokes GetDpp.
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
{
return
DppInstr
::
dpp8_f16_2x16x2
;
}
// GetDpp specialization for <half_t, 4, 16>: selects DppInstr::dpp8_f16_4x16x2.
// The three template arguments correspond to DppSelector's
// <BaseType, MPerDpp, NPerDpp>, which is how selected_dpp invokes GetDpp.
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
{
return
DppInstr
::
dpp8_f16_4x16x2
;
}
// GetDpp specialization for <half_t, 4, 32>: selects DppInstr::dpp8_f16_4x32x2.
// The three template arguments correspond to DppSelector's
// <BaseType, MPerDpp, NPerDpp>, which is how selected_dpp invokes GetDpp.
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
{
return
DppInstr
::
dpp8_f16_4x32x2
;
}
static
constexpr
auto
selected_dpp
=
dpp_type
<
GetDpp
<
BaseType
,
MPerDpp
,
NPerDpp
>
()
>
{};
__host__
__device__
constexpr
DppSelector
()
...
...
@@ -191,7 +413,6 @@ struct DppSelector
// in the future when the implementation is more generalized.
static_assert
(
selected_dpp
.
share_a
);
static_assert
(
selected_dpp
.
n_per_thread
==
1
);
static_assert
(
selected_dpp
.
m_per_thread
==
selected_dpp
.
lanegroup_size
);
static_assert
(
selected_dpp
.
m_per_lanegroup
==
selected_dpp
.
m_per_thread
);
static_assert
(
selected_dpp
.
n_per_lanegroup
==
selected_dpp
.
n_per_thread
*
selected_dpp
.
lanegroup_size
);
...
...
@@ -215,11 +436,6 @@ struct DppGemm
__host__
__device__
constexpr
DppGemm
()
{
static_assert
(
MPerDpp
==
8
||
MPerDpp
==
16
||
MPerDpp
==
32
,
"MPerDpp must be either 8, 16 or 32."
);
static_assert
(
NPerDpp
==
8
||
NPerDpp
==
16
||
NPerDpp
==
32
,
"NPerDpp must be either 8, 16 or 32."
);
static_assert
(
KPack
%
dpp_instr
.
k_per_dpp
==
0
,
"KPack must be divisible by k_per_dpp."
);
}
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
9b1437db
...
...
@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
#if defined CK_ENABLE_FP8
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
{
...
...
@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
#endif
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
>
struct
MfmaSelector
...
...
@@ -640,6 +642,7 @@ struct MfmaSelector
}
#endif
#if defined CK_ENABLE_FP8
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
...
...
@@ -651,6 +654,7 @@ struct MfmaSelector
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
#endif
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
...
...
@@ -852,7 +856,11 @@ struct XdlopsGemm
{
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
,
is_same
<
base_type
,
int8_t
>::
value
#if defined CK_ENABLE_FP8
||
is_same
<
base_type
,
f8_t
>::
value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
9b1437db
...
...
@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
...
...
@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
else
{
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
#endif
#else
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
...
...
@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
else
{
#endif
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#if defined CK_ENABLE_FP8
}
#endif
#endif
}
// buffer_load requires:
...
...
@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
...
...
@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
#endif
#else
if
(
dst_thread_element_valid
)
{
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
...
...
@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
#endif
}
#endif
}
...
...
include/ck/utility/amd_xdlops.hpp
View file @
9b1437db
...
...
@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
#if defined CK_ENABLE_FP8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8f8
;
...
...
@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
}
};
#endif
}
// namespace ck
#endif
include/ck/utility/data_type.hpp
View file @
9b1437db
...
...
@@ -12,7 +12,12 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
#endif
using
f8_t
=
uint8_t
;
#if defined CK_ENABLE_FP8
using
f8_t
=
_BitInt
(
8
);
#endif
#if defined CK_ENABLE_BF8
using
bf8_t
=
unsigned
_BitInt
(
8
);
#endif
// vector_type
template
<
typename
T
,
index_t
N
>
...
...
@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
};
#endif
#if defined CK_ENABLE_FP8
// scalar_type traits for f8_t: the scalar element type is f8_t itself and
// the type holds exactly one element.
template <>
struct scalar_type<f8_t>
{
    static constexpr index_t vector_size = 1;
    using type                           = f8_t;
};
#endif
#if defined CK_ENABLE_BF8
// scalar_type traits for bf8_t: the scalar element type is bf8_t itself and
// the type holds exactly one element.
template <>
struct scalar_type<bf8_t>
{
    static constexpr index_t vector_size = 1;
    using type                           = bf8_t;
};
#endif
//
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
{
...
...
@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
#if defined CK_ENABLE_FP8
// Convenience aliases for vector_type<f8_t, N>::type, N in {2, 4, 8, 16, 32, 64}.
using f8x2_t  = typename vector_type<f8_t, 2>::type;
using f8x4_t  = typename vector_type<f8_t, 4>::type;
using f8x8_t  = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8
#if defined CK_ENABLE_BF8
// Convenience aliases for vector_type<bf8_t, N>::type, N in {2, 4, 8, 16, 32, 64}.
using bf8x2_t  = typename vector_type<bf8_t, 2>::type;
using bf8x4_t  = typename vector_type<bf8_t, 4>::type;
using bf8x8_t  = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template
<
typename
T
>
struct
NumericLimits
...
...
@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template
<
>
struct
NumericLimits
<
f8_t
>
{
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_max
=
0x77
;
// 0b01110111
static
constexpr
uint8_t
binary_lowest
=
0xF7
;
// 0b11110111
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
f8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
f8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
f8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
NumericLimits
<
bf8_t
>
{
// negative zero nan mode with exp bias = 16
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
b
it_cast
<
f8_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
b
f8_t
Min
()
{
return
bf8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
b
it_cast
<
f8_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
b
f8_t
Max
()
{
return
bf8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
b
it_cast
<
f8_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
b
f8_t
Lowest
()
{
return
bf8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
b
it_cast
<
f8_t
>
(
binary_qnan
);
}
__host__
__device__
static
constexpr
b
f8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
#endif
// Primary template: deliberately empty. Per-type bit-layout constants are
// supplied by explicit specializations.
template <typename T>
struct NumericUtils
{
};

// Bit-layout constants for 32-bit float: 8 exponent bits, 23 mantissa bits.
// The masks and special-value encodings are derived from exp/mant; the
// resulting value of each is noted alongside.
template <>
struct NumericUtils<float>
{
    // Unsigned integer type with the same width as float, used for bit access.
    using bitwise_type = uint32_t;

    static constexpr int exp  = 8;
    static constexpr int mant = 23;

    // Field masks. mant_mask and head_mask apply to the raw bits; exp_mask
    // applies after shifting the exponent field down by `mant`.
    static constexpr uint32_t mant_mask = (uint32_t{1} << mant) - 1; // 0x007FFFFF
    static constexpr uint32_t exp_mask  = (uint32_t{1} << exp) - 1;  // 0x000000FF
    static constexpr uint32_t nan_mask  = exp_mask << mant;          // 0x7F800000

    // Special-value encodings.
    static constexpr uint32_t Neg0   = uint32_t{1} << (exp + mant); // 0x80000000
    static constexpr uint32_t Inf    = nan_mask;                    // 0x7F800000
    static constexpr uint32_t NegInf = Neg0 | Inf;                  // 0xFF800000
    static constexpr uint32_t NaN    = Inf | uint32_t{1};           // 0x7F800001

    // Sign bit plus exponent field.
    static constexpr uint32_t head_mask = Neg0 | nan_mask; // 0xFF800000
};
// Bit-layout constants for half_t: 5 exponent bits, 10 mantissa bits.
// The masks and special-value encodings are derived from exp/mant; the
// resulting value of each is noted alongside.
// NOTE(review): the special-value constants are uint32_t while the masks and
// bitwise_type are uint16_t; the types are kept exactly as they were.
template <>
struct NumericUtils<half_t>
{
    // Unsigned integer type used for bit access to a half_t value.
    using bitwise_type = uint16_t;

    static constexpr int exp  = 5;
    static constexpr int mant = 10;

    // Field masks. mant_mask and head_mask apply to the raw bits; exp_mask
    // applies after shifting the exponent field down by `mant`.
    static constexpr uint16_t mant_mask = static_cast<uint16_t>((1u << mant) - 1u); // 0x03FF
    static constexpr uint16_t exp_mask  = static_cast<uint16_t>((1u << exp) - 1u);  // 0x001F
    static constexpr uint16_t nan_mask  =
        static_cast<uint16_t>(static_cast<uint32_t>(exp_mask) << mant); // 0x7C00

    // Special-value encodings (declared as uint32_t).
    static constexpr uint32_t Neg0   = uint32_t{1} << (exp + mant); // 0x8000
    static constexpr uint32_t Inf    = nan_mask;                    // 0x7C00
    static constexpr uint32_t NegInf = Neg0 | Inf;                  // 0xFC00
    static constexpr uint32_t NaN    = Inf | uint32_t{1};           // 0x7C01

    // Sign bit plus exponent field.
    static constexpr uint16_t head_mask = static_cast<uint16_t>(Neg0 | nan_mask); // 0xFC00
};
#if defined CK_ENABLE_FP8
// Field widths for f8_t: 4 exponent bits and 3 mantissa bits
// (4 + 3 = 7, leaving one bit of the 8-bit storage).
template <>
struct NumericUtils<f8_t>
{
    static constexpr int mant = 3;
    static constexpr int exp  = 4;
};
#endif
#if defined CK_ENABLE_BF8
// Field widths for bf8_t: 5 exponent bits and 2 mantissa bits
// (5 + 2 = 7, leaving one bit of the 8-bit storage).
template <>
struct NumericUtils<bf8_t>
{
    static constexpr int mant = 2;
    static constexpr int exp  = 5;
};
#endif
}
// namespace ck
include/ck/utility/f8_utils.hpp
View file @
9b1437db
...
...
@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp"
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace
ck
{
// fp8 rounding modes
...
...
@@ -22,53 +23,38 @@ namespace ck::utils {
namespace
{
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
Y
run_cast_to_f8
(
X
x
,
uint32_t
rng
)
{
//
check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
//
fp8/bf8 exponent/mantissa layout
constexpr
int
out_exp
=
NumericUtils
<
Y
>::
exp
;
constexpr
int
out_mant
=
NumericUtils
<
Y
>::
mant
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// resulting type exponent/mantissa layout
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_mant
=
is_half
?
10
:
23
;
// original type exponent/mantissa layout
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
int
exponent
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
constexpr
uint8_t
nan_code
=
0x80
;
constexpr
uint32_t
nan_mask
=
is_half
?
0x7C00
:
0x7F800000
;
constexpr
Y
nan_code
=
0x80
;
constexpr
uint32_t
nan_mask
=
NumericUtils
<
X
>::
nan_mask
;
// convert to bitwise
typedef
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
T_bitwise
;
using
T_bitwise
=
typename
NumericUtils
<
X
>::
bitwise_type
;
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
));
// unpack the input, depends on datatype
if
constexpr
(
is_float
)
{
head
=
x_bitwise
&
0xFF800000
;
mantissa
=
x_bitwise
&
0x7FFFFF
;
exponent
=
(
head
>>
type_mant
)
&
0xFF
;
sign
=
head
>>
(
type_exp
+
type_mant
);
}
else
if
constexpr
(
is_half
)
{
head
=
x_bitwise
&
0xFC00
;
mantissa
=
x_bitwise
&
0x3FF
;
exponent
=
(
head
>>
type_mant
)
&
0x1F
;
sign
=
head
>>
(
type_exp
+
type_mant
);
}
uint32_t
signed_inf
=
(
sign
<<
(
type_exp
+
type_mant
))
+
(((
1
<<
type_exp
)
-
1
)
<<
type_mant
);
uint32_t
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
f8_exp
)
-
(
negative_zero_nan
?
1
:
2
);
head
=
x_bitwise
&
NumericUtils
<
X
>::
head_mask
;
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
sign
=
head
>>
(
in_exp
+
in_mant
);
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type
_exp
-
1
))
-
(
1
<<
(
f8
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
(
1
<<
(
in
_exp
-
1
))
-
(
1
<<
(
out
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
if
constexpr
(
negative_zero_nan
)
{
...
...
@@ -81,22 +67,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
// if input is half and output is bf8
if
((
NumericUtils
<
X
>::
mant
==
10
)
&&
(
NumericUtils
<
Y
>::
mant
==
2
)
&&
negative_zero_nan
&&
exponent
==
0
)
{
exponent
+=
1
;
while
(
mantissa
<
(
1
<<
in_mant
))
{
mantissa
<<=
1
;
exponent
-=
1
;
}
mantissa
&=
~
(
1
<<
in_mant
);
}
// check if x is 0.0
if
(
x_bitwise
==
0
)
return
0
;
exponent
-=
exp_low_cutoff
-
1
;
if
(
exponent
<=
0
)
drop_mask
=
(
1
<<
(
type
_mant
-
f8
_mant
+
1
-
exponent
))
-
1
;
mantissa
+=
1
<<
type
_mant
;
drop_mask
=
(
1
<<
(
in
_mant
-
out
_mant
+
1
-
exponent
))
-
1
;
mantissa
+=
1
<<
in
_mant
;
// apply random number if needed
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
if
(
mantissa
>=
(
2
<<
type
_mant
))
if
(
mantissa
>=
(
2
<<
in
_mant
))
{
mantissa
>>=
1
;
exponent
++
;
}
mantissa
>>=
(
type
_mant
-
f8
_mant
);
mantissa
>>=
(
in
_mant
-
out
_mant
);
// check negative exponent
if
(
exponent
<=
0
)
...
...
@@ -116,7 +115,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{
if
(
clip
)
{
mantissa
=
(
1
<<
f8
_mant
)
-
1
;
mantissa
=
(
1
<<
out
_mant
)
-
1
;
exponent
=
max_exp
;
}
else
...
...
@@ -127,124 +126,120 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
// check if x is 0.0 or -0.0
if
(
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
f8
_exp
+
f8
_mant
));
mantissa
&=
(
1
<<
f8
_mant
)
-
1
;
return
(
sign
<<
(
f8
_exp
+
f8
_mant
))
|
(
exponent
<<
f8
_mant
)
|
mantissa
;
return
negative_zero_nan
?
0
:
(
sign
<<
(
out
_exp
+
out
_mant
));
mantissa
&=
(
1
<<
out
_mant
)
-
1
;
return
(
sign
<<
(
out
_exp
+
out
_mant
))
|
(
exponent
<<
out
_mant
)
|
mantissa
;
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
run_cast_from_f8
(
f8_t
x
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
__host__
__device__
Y
run_cast_from_f8
(
X
x
)
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// fp8/bf8 exponent/mantissa layout
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
// resulting type exponent/mantissa layout
constexpr
int
type
_exp
=
is_half
?
5
:
8
;
constexpr
int
type
_mant
=
is_half
?
10
:
23
;
constexpr
int
out
_exp
=
NumericUtils
<
Y
>::
exp
;
constexpr
int
out
_mant
=
NumericUtils
<
Y
>::
mant
;
// prepare the codes
constexpr
uint8_t
nan_code
=
0x80
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
constexpr
(
is_half
)
{
constexpr
uint16_t
ihInf
=
0x7C00
;
constexpr
uint16_t
ihNegInf
=
0xFC00
;
constexpr
uint16_t
ihNaN
=
0x7C01
;
constexpr
uint16_t
ihNeg0
=
0x8000
;
fInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNeg0
));
}
else
if
constexpr
(
is_float
)
{
constexpr
uint32_t
ifInf
=
0x7F800000
;
constexpr
uint32_t
ifNegInf
=
0xFF800000
;
constexpr
uint32_t
ifNaN
=
0x7F800001
;
constexpr
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
}
constexpr
X
nan_code
=
0x80
;
Y
Inf
,
NegInf
,
NaN
,
Neg0
;
using
T_bitwise
=
typename
NumericUtils
<
Y
>::
bitwise_type
;
constexpr
T_bitwise
Inf_bitwise
=
NumericUtils
<
Y
>::
Inf
;
constexpr
T_bitwise
NegInf_bitwise
=
NumericUtils
<
Y
>::
NegInf
;
constexpr
T_bitwise
NaN_bitwise
=
NumericUtils
<
Y
>::
NaN
;
constexpr
T_bitwise
Neg0_bitwise
=
NumericUtils
<
Y
>::
Neg0
;
Inf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Inf_bitwise
));
NegInf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NegInf_bitwise
));
NaN
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NaN_bitwise
));
Neg0
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Neg0_bitwise
));
// check if x is 0.0
if
(
x
==
0
)
return
static_cast
<
Y
>
(
0
);
// unpack the input
uint32_t
sign
=
x
>>
(
f8
_exp
+
f8
_mant
);
uint32_t
mantissa
=
x
&
((
1
<<
f8
_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
f8
_mant
;
uint32_t
sign
=
x
>>
(
in
_exp
+
in
_mant
);
uint32_t
mantissa
=
x
&
((
1
<<
in
_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
in
_mant
;
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type
_exp
-
1
))
-
(
1
<<
(
f8
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
typ
e
retval
;
(
1
<<
(
out
_exp
-
1
))
-
(
1
<<
(
in
_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
T_bitwis
e
retval
;
if
constexpr
(
negative_zero_nan
)
{
if
(
x
==
nan_code
)
return
f
NaN
;
return
NaN
;
}
else
{
if
(
x
==
nan_code
)
return
fNeg0
;
if
(
exponent
==
((
1
<<
f8_exp
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
return
Neg0
;
if
(
exponent
==
((
1
<<
in_exp
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
NegInf
:
Inf
)
:
NaN
;
}
if
((
NumericUtils
<
Y
>::
mant
==
10
)
&&
(
NumericUtils
<
X
>::
mant
==
2
)
&&
!
negative_zero_nan
)
{
retval
=
x
;
retval
<<=
8
;
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
// subnormal input
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
((
1
+
type_exp
+
type_mant
)
-
f8_mant
);
mantissa
<<=
sh
;
mantissa
&=
((
1
<<
f8_mant
)
-
1
);
exponent
+=
1
-
sh
;
exponent
++
;
while
(
mantissa
<
(
1
<<
in_mant
))
{
mantissa
<<=
1
;
exponent
--
;
}
mantissa
&=
((
1
<<
in_mant
)
-
1
);
}
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
type
_mant
-
f8
_mant
;
mantissa
<<=
out
_mant
-
in
_mant
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
{
mantissa
|=
1
<<
type
_mant
;
mantissa
|=
1
<<
out
_mant
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
retval
=
(
sign
<<
(
type
_exp
+
type
_mant
))
|
(
exponent
<<
type
_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
T
*>
(
&
retval
));
retval
=
(
sign
<<
(
out
_exp
+
out
_mant
))
|
(
exponent
<<
out
_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
}
// namespace
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
Y
cast_to_f8
(
X
x
,
uint32_t
rng
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted
to f8
."
);
// check datatype
s
constexpr
bool
is_half
=
std
::
is_same
<
X
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
X
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted."
);
return
run_cast_to_f8
<
T
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
return
run_cast_to_f8
<
X
,
Y
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
cast_from_f8
(
f8_t
x
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
__host__
__device__
Y
cast_from_f8
(
X
x
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
Y
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
Y
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported."
);
// check if x is 0.0
if
(
x
==
0
)
return
static_cast
<
T
>
(
0
);
return
run_cast_from_f8
<
T
,
negative_zero_nan
>
(
x
);
return
run_cast_from_f8
<
X
,
Y
,
negative_zero_nan
>
(
x
);
}
}
// namespace ck::utils
#endif
include/ck/utility/inner_product.hpp
View file @
9b1437db
...
...
@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c
);
}
// Scalar bhalf_t x bhalf_t specialization with a float third argument.
// Each operand is widened to float with type_convert<float>, and the call is
// forwarded to the inner_product overload chosen for (float, float, float&);
// `c` is passed through by reference untouched by this wrapper itself.
template <>
__device__ void
inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
{
    const float lhs = type_convert<float>(a);
    const float rhs = type_convert<float>(b);

    inner_product(lhs, rhs, c);
}
// Scalar half_t x half_t specialization with a float third argument.
// Each operand is widened to float with type_convert<float>, and the call is
// forwarded to the inner_product overload chosen for (float, float, float&);
// `c` is passed through by reference untouched by this wrapper itself.
template <>
__device__ void
inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
    const float lhs = type_convert<float>(a);
    const float rhs = type_convert<float>(b);

    inner_product(lhs, rhs, c);
}
template
<
>
__device__
void
inner_product
<
half2_t
,
half2_t
,
float
>
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
)
{
...
...
include/ck/utility/type_convert.hpp
View file @
9b1437db
...
...
@@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
#if defined CK_ENABLE_FP8
// convert fp32 to fp8
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
...
...
@@ -88,8 +89,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
}
// convert fp8 to fp32
...
...
@@ -97,7 +99,7 @@ template <>
inline
__host__
__device__
float
type_convert
<
float
,
f8_t
>
(
f8_t
x
)
{
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
x
);
}
// convert fp16 to fp8
...
...
@@ -108,7 +110,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
}
...
...
@@ -117,8 +120,53 @@ template <>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_t
>
(
f8_t
x
)
{
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
}
#endif
#if defined CK_ENABLE_BF8
// fp32 -> bf8 conversion using the standard (non-stochastic) rounding mode.
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
    // Rounding policy: `standard`, so the stochastic flag below evaluates to false.
    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
    constexpr bool use_stochastic = (rm == f8_rounding_mode::stochastic);

    // Remaining compile-time switches forwarded to utils::cast_to_f8.
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;

    // Random-bits argument of cast_to_f8; a constant 0 is supplied here.
    constexpr uint32_t rng = 0;

    return utils::cast_to_f8<float, bf8_t, negative_zero_nan, clip, use_stochastic>(x, rng);
}
// bf8 -> fp32 conversion; delegates to utils::cast_from_f8 with the
// negative_zero_nan switch turned on.
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
{
    constexpr bool nan_is_negative_zero = true;

    return utils::cast_from_f8<bf8_t, float, nan_is_negative_zero>(x);
}
// fp16 -> bf8 conversion using the standard (non-stochastic) rounding mode.
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
    // Rounding policy: `standard`, so the stochastic flag below evaluates to false.
    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
    constexpr bool use_stochastic = (rm == f8_rounding_mode::stochastic);

    // Remaining compile-time switches forwarded to utils::cast_to_f8.
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;

    // Random-bits argument of cast_to_f8; a constant 0 is supplied here.
    constexpr uint32_t rng = 0;

    return utils::cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, use_stochastic>(x, rng);
}
// bf8 -> fp16 conversion; delegates to utils::cast_from_f8 with the
// negative_zero_nan switch turned on.
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
{
    constexpr bool nan_is_negative_zero = true;

    return utils::cast_from_f8<bf8_t, half_t, nan_is_negative_zero>(x);
}
#endif
// Declare a template function for bf16 conversion using RTN
template
<
typename
Y
,
typename
X
>
...
...
@@ -181,6 +229,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
...
...
@@ -191,8 +240,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
}
// convert fp16 to fp8 with stochastic rounding
...
...
@@ -205,8 +255,42 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
}
#endif
#if defined CK_ENABLE_BF8
// fp32 -> bf8 conversion with stochastic rounding.
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
    // Rounding policy: `stochastic`, so the flag below evaluates to true.
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr bool use_stochastic = (rm == f8_rounding_mode::stochastic);

    // Remaining compile-time switches forwarded to utils::cast_to_f8.
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;

    // Fixed compile-time seed for the pseudo-random generator.
    constexpr int seed = 42;

    // The generator is fed the address of the by-value argument `x` (cast to
    // uintptr_t) together with the value of `x`.
    // NOTE(review): presumably the address stands in for a thread id, which is
    // not available on host -- confirm against prand_generator.
    const uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);

    return utils::cast_to_f8<float, bf8_t, negative_zero_nan, clip, use_stochastic>(x, rng);
}
// fp16 -> bf8 conversion with stochastic rounding.
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
    // Rounding policy: `stochastic`, so the flag below evaluates to true.
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr bool use_stochastic = (rm == f8_rounding_mode::stochastic);

    // Remaining compile-time switches forwarded to utils::cast_to_f8.
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;

    // Fixed compile-time seed for the pseudo-random generator.
    constexpr int seed = 42;

    // The generator is fed the address of the by-value argument `x` (cast to
    // uintptr_t) together with the value of `x`.
    // NOTE(review): presumably the address stands in for a thread id, which is
    // not available on host -- confirm against prand_generator.
    const uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);

    return utils::cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, use_stochastic>(x, rng);
}
#endif
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
9b1437db
...
...
@@ -20,7 +20,8 @@ template <typename ADataType,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
,
typename
ComputType
=
ADataType
>
struct
ReferenceGemm
:
public
device
::
BaseOperator
{
// Argument
...
...
@@ -64,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
AData
Type
v_a
;
BData
Type
v_b
;
Comput
Type
v_a
;
Comput
Type
v_b
;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
9b1437db
...
...
@@ -17,10 +17,15 @@ namespace instance {
using
F64
=
double
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F8
=
ck
::
f8_t
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
#if defined CK_ENABLE_FP8
using
F8
=
ck
::
f8_t
;
#endif
#if defined CK_ENABLE_BF8
using
BF8
=
ck
::
bf8_t
;
#endif
using
Empty_Tuple
=
ck
::
Tuple
<>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
View file @
9b1437db
...
...
@@ -23,12 +23,17 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_d
l
_f16_f16_f16_km_kn_mn_irregular_instances
(
void
add_device_gemm_d
pp
_f16_f16_f16_km_kn_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
...
...
@@ -38,12 +43,17 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_d
l
_f16_f16_f16_km_nk_mn_irregular_instances
(
void
add_device_gemm_d
pp
_f16_f16_f16_km_nk_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
...
...
@@ -53,12 +63,17 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_d
l
_f16_f16_f16_mk_kn_mn_irregular_instances
(
void
add_device_gemm_d
pp
_f16_f16_f16_mk_kn_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
...
...
@@ -68,12 +83,17 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_d
l
_f16_f16_f16_mk_nk_mn_irregular_instances
(
void
add_device_gemm_d
pp
_f16_f16_f16_mk_nk_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
...
...
@@ -375,6 +395,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances
(
op_ptrs
);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
...
...
@@ -386,6 +407,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances
(
op_ptrs
);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
...
...
@@ -398,6 +420,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances
(
op_ptrs
);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
...
...
@@ -409,6 +432,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances
(
op_ptrs
);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp
View file @
9b1437db
...
...
@@ -45,6 +45,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_
PassThrough
,
MultiplyAdd
>>>&
);
#if defined CK_ENABLE_FP8
void
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Row
,
...
...
@@ -70,6 +71,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_m
PassThrough
,
PassThrough
,
MultiplyAdd
>>>&
);
#endif
// GEMM + Multiply + Add
template
<
typename
ALayout
,
...
...
@@ -131,6 +133,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
D0DataType
,
float
>
&&
is_same_v
<
D1DataType
,
float
>
&&
is_same_v
<
EDataType
,
half_t
>
)
...
...
@@ -150,6 +153,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
View file @
9b1437db
...
...
@@ -57,6 +57,7 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#if defined CK_ENABLE_FP8
void
add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Row
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -96,6 +97,7 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
template
<
typename
ADataType
,
typename
BDataType
,
...
...
@@ -176,6 +178,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
#if defined CK_ENABLE_FP8
else
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
...
...
@@ -224,6 +227,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
0 → 100644
View file @
9b1437db
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Bring the convolution tensor-layout names into this namespace.
using namespace ck::tensor_layout::convolution;

// Short data-type aliases used in the instance lists below.
using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;

// Empty ck::Tuple alias.
using Empty_Tuple = ck::Tuple<>;

// Shorthand for a compile-time ck::Sequence of index_t values.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Element-wise operation alias used for all three operation slots of the instances below.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Named ConvolutionBackwardWeightSpecialization values.
static constexpr auto ConvBwdWeightDefault =
    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;

static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
// Tuple of DeviceGroupedConvBwdWeight_Dl instantiations whose four data-type
// slots (InData, WeiData, OutData, AccData per the column header below) are all F32.
// The spatial dimension count, the three layouts and the convolution
// specialization are supplied by the user of this alias.
// The list currently holds a single "generic" instance.
template <ck::index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename ELayout,
          ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_dl_f32_instances = std::tuple<
    // clang-format off
    //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
    //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector|
    //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | |
    //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // generic instance
    DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>
    // clang-format on
    >;
// Tuple of DeviceGroupedConvBwdWeight_Dl instantiations whose InData, WeiData
// and OutData slots (per the column header below) are F16 and whose AccData slot is F32.
// The spatial dimension count, the three layouts and the convolution
// specialization are supplied by the user of this alias.
// The list currently holds a single "generic" instance.
template <ck::index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename ELayout,
          ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_dl_f16_instances = std::tuple<
    // clang-format off
    //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
    //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector|
    //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | |
    //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // generic instance
    DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>
    // clang-format on
    >;
// Tuple of DeviceGroupedConvBwdWeight_Dl instantiations whose InData and
// OutData slots (per the column header below) are BF16, while the WeiData and
// AccData slots are F32.
// The spatial dimension count, the three layouts and the convolution
// specialization are supplied by the user of this alias.
// The list currently holds a single "generic" instance.
template <ck::index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename ELayout,
          ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_dl_bf16_instances = std::tuple<
    // clang-format off
    //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
    //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector|
    //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | |
    //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // generic instance
    DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>
    // clang-format on
    >;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp
View file @
9b1437db
...
...
@@ -55,8 +55,8 @@ using device_grouped_conv2d_fwd_dl_f16_instances = std::tuple<
// ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instances
// TODO: Change to ScalarPerVector = 1 when inner_product<half_t, half_t, float> will be supported
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
<
2
,
F16
,
F16
,
DsDatatype
,
F16
,
F32
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
PassThrough
,
PassThrough
,
CDEElementOp
,
ConvSpec
,
GemmMNKPadding
,
8
,
16
,
4
,
2
,
2
,
1
,
2
,
1
,
S
<
4
,
2
>
,
S
<
1
,
1
>
,
S
<
2
,
1
,
2
,
2
>
,
S
<
1
,
1
,
8
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
2
,
1
,
4
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
2
>
,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
<
2
,
F16
,
F16
,
DsDatatype
,
F16
,
F32
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
PassThrough
,
PassThrough
,
CDEElementOp
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
16
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
2
>
,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
<
2
,
F16
,
F16
,
DsDatatype
,
F16
,
F32
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
PassThrough
,
PassThrough
,
CDEElementOp
,
ConvSpec
,
GemmMNKPadding
,
8
,
16
,
4
,
2
,
1
,
1
,
2
,
1
,
S
<
4
,
2
>
,
S
<
1
,
1
>
,
S
<
2
,
1
,
2
,
1
>
,
S
<
1
,
1
,
8
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
2
,
1
,
4
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
>
,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
<
2
,
F16
,
F16
,
DsDatatype
,
F16
,
F32
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
PassThrough
,
PassThrough
,
CDEElementOp
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
16
,
1
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
>
,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
<
2
,
F16
,
F16
,
DsDatatype
,
F16
,
F32
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
PassThrough
,
PassThrough
,
CDEElementOp
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
16
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
View file @
9b1437db
This diff is collapsed.
Click to expand it.
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
0 → 100644
View file @
9b1437db
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// fp16_output
/// Declaration only; the definition lives outside this header.
/// Takes the caller's instance list by non-const reference.
/// Template arguments follow the factory below: A layout Row, B layout Row,
/// E layout Row; A/B/E data types F16/F16/F16; PassThrough element-wise ops.
/// NOTE(review): presumably appends the matching XDL instances to `instances` — confirm
/// against the definition.
void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                         Row,
                                                         Empty_Tuple,
                                                         Row,
                                                         F16,
                                                         F16,
                                                         Empty_Tuple,
                                                         F16,
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances);
/// Declaration only; the definition lives outside this header.
/// Takes the caller's instance list by non-const reference.
/// Template arguments follow the factory below: A layout Row, B layout Col,
/// E layout Row; A/B/E data types F16/F16/F16; PassThrough element-wise ops.
/// NOTE(review): presumably appends the matching XDL instances to `instances` — confirm
/// against the definition.
void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                         Col,
                                                         Empty_Tuple,
                                                         Row,
                                                         F16,
                                                         F16,
                                                         Empty_Tuple,
                                                         F16,
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances);
// fp8_inputB
/// Declaration only; the definition lives outside this header.
/// Takes the caller's instance list by non-const reference.
/// Template arguments follow the factory below: A layout Row, B layout Row,
/// E layout Row; A/B/E data types F16/F8/F16; PassThrough element-wise ops.
/// NOTE(review): presumably appends the matching XDL instances to `instances` — confirm
/// against the definition.
void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                         Row,
                                                         Empty_Tuple,
                                                         Row,
                                                         F16,
                                                         F8,
                                                         Empty_Tuple,
                                                         F16,
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances);
/// Declaration only; the definition lives outside this header.
/// Takes the caller's instance list by non-const reference.
/// Template arguments follow the factory below: A layout Row, B layout Col,
/// E layout Row; A/B/E data types F16/F8/F16; PassThrough element-wise ops.
/// NOTE(review): presumably appends the matching XDL instances to `instances` — confirm
/// against the definition.
void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                         Col,
                                                         Empty_Tuple,
                                                         Row,
                                                         F16,
                                                         F8,
                                                         Empty_Tuple,
                                                         F16,
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances);
// i8_inputB
/// Declaration only; the definition lives outside this header.
/// Takes the caller's instance list by non-const reference.
/// Template arguments follow the factory below: A layout Row, B layout Row,
/// E layout Row; A/B/E data types F16/I8/F16; PassThrough element-wise ops.
/// NOTE(review): presumably appends the matching XDL instances to `instances` — confirm
/// against the definition.
void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                         Row,
                                                         Empty_Tuple,
                                                         Row,
                                                         F16,
                                                         I8,
                                                         Empty_Tuple,
                                                         F16,
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances);
/// Declaration only; the definition lives outside this header.
/// Takes the caller's instance list by non-const reference.
/// Template arguments follow the factory below: A layout Row, B layout Col,
/// E layout Row; A/B/E data types F16/I8/F16; PassThrough element-wise ops.
/// NOTE(review): presumably appends the matching XDL instances to `instances` — confirm
/// against the definition.
void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                         Col,
                                                         Empty_Tuple,
                                                         Row,
                                                         F16,
                                                         I8,
                                                         Empty_Tuple,
                                                         F16,
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances);
template
<
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmFixedNK
<
ALayout
,
BLayout
,
Empty_Tuple
,
ELayout
,
ADataType
,
BDataType
,
Empty_Tuple
,
EDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
{
using
DeviceOp
=
DeviceGroupedGemmFixedNK
<
ALayout
,
BLayout
,
Empty_Tuple
,
ELayout
,
ADataType
,
BDataType
,
Empty_Tuple
,
EDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
// fp16_output
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
}
// fp8_input
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances
(
op_ptrs
);
}
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances
(
op_ptrs
);
}
}
// i8_input
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances
(
op_ptrs
);
}
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/utility/check_err.hpp
View file @
9b1437db
...
...
@@ -230,5 +230,99 @@ check_err(const Range& out,
return
res
;
}
#if defined CK_ENABLE_FP8
/// Element-wise comparison of two ranges whose value type is f8_t.
///
/// Each element is converted with type_convert<float> and compared in double precision.
/// An element counts as a mismatch when |out - ref| > atol + rtol * |ref|, or when either
/// value is not finite.
///
/// @param out  Range under test.
/// @param ref  Reference range; must have the same size() as `out`.
/// @param msg  Prefix written to std::cerr in front of each diagnostic line.
/// @param rtol Relative tolerance, scaled by |ref|.
/// @param atol Absolute tolerance.
/// @return true when the sizes match and no element is a mismatch; false otherwise.
///
/// Diagnostics go to std::cerr: a size-mismatch message, the first four mismatching
/// elements, and the largest error seen among mismatching elements.
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_same_v<ranges::range_value_t<Range>, f8_t>),
                 bool>
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-3,
          double atol            = 1e-3)
{
    // Ranges of different length cannot be compared element by element.
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    bool all_match     = true;
    int mismatch_count = 0;
    double largest_err = std::numeric_limits<float>::min();

    auto out_it = std::begin(out);
    auto ref_it = std::begin(ref);
    for(std::size_t idx = 0; idx < ref.size(); ++idx, ++out_it, ++ref_it)
    {
        const double out_val = type_convert<float>(*out_it);
        const double ref_val = type_convert<float>(*ref_it);
        const double abs_err = std::abs(out_val - ref_val);

        const bool exceeds_tolerance = abs_err > atol + rtol * std::abs(ref_val);
        const bool both_finite       = std::isfinite(out_val) && std::isfinite(ref_val);
        if(!exceeds_tolerance && both_finite)
        {
            continue;
        }

        // Only mismatching elements contribute to the reported maximum.
        if(abs_err > largest_err)
        {
            largest_err = abs_err;
        }

        // Limit per-element output to the first four mismatches.
        if(++mismatch_count < 5)
        {
            std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << idx
                      << "] != ref[" << idx << "]: " << out_val << " != " << ref_val << std::endl;
        }
        all_match = false;
    }

    if(!all_match)
    {
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << largest_err
                  << std::endl;
    }
    return all_match;
}
#endif
#if defined CK_ENABLE_BF8
/// Element-wise comparison of two ranges whose value type is bf8_t.
///
/// Each element is converted with type_convert<float> and compared in double precision.
/// An element counts as a mismatch when |out - ref| > atol + rtol * |ref|, or when either
/// value is not finite.
///
/// @param out  Range under test.
/// @param ref  Reference range; must have the same size() as `out`.
/// @param msg  Prefix written to std::cerr in front of each diagnostic line.
/// @param rtol Relative tolerance, scaled by |ref|.
/// @param atol Absolute tolerance.
/// @return true when the sizes match and no element is a mismatch; false otherwise.
///
/// Diagnostics go to std::cerr: a size-mismatch message, the first four mismatching
/// elements, and the largest error seen among mismatching elements.
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
                 bool>
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-3,
          double atol            = 1e-3)
{
    // Ranges of different length cannot be compared element by element.
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    bool all_match     = true;
    int mismatch_count = 0;
    double largest_err = std::numeric_limits<float>::min();

    auto out_it = std::begin(out);
    auto ref_it = std::begin(ref);
    for(std::size_t idx = 0; idx < ref.size(); ++idx, ++out_it, ++ref_it)
    {
        const double out_val = type_convert<float>(*out_it);
        const double ref_val = type_convert<float>(*ref_it);
        const double abs_err = std::abs(out_val - ref_val);

        const bool exceeds_tolerance = abs_err > atol + rtol * std::abs(ref_val);
        const bool both_finite       = std::isfinite(out_val) && std::isfinite(ref_val);
        if(!exceeds_tolerance && both_finite)
        {
            continue;
        }

        // Only mismatching elements contribute to the reported maximum.
        if(abs_err > largest_err)
        {
            largest_err = abs_err;
        }

        // Limit per-element output to the first four mismatches.
        if(++mismatch_count < 5)
        {
            std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << idx
                      << "] != ref[" << idx << "]: " << out_val << " != " << ref_val << std::endl;
        }
        all_match = false;
    }

    if(!all_match)
    {
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << largest_err
                  << std::endl;
    }
    return all_match;
}
#endif
}
// namespace utils
}
// namespace ck
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment