Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f0298581
Commit
f0298581
authored
Oct 18, 2023
by
Harisankar Sadasivan
Browse files
cmakelist changes to exclude navi cards for gemv splitk & merge changes from dev
parent
675aa69e
Changes
129
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
614 additions
and
892 deletions
+614
-892
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+8
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+0
-6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
+3
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+27
-14
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+18
-9
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+54
-23
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
+55
-89
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+4
-28
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+210
-628
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+2
-13
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+4
-26
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+0
-4
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+9
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+135
-17
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+2
-6
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp
...nsor_operation_instance/gpu/convolution_backward_data.hpp
+16
-6
library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp
...ary/tensor_operation_instance/gpu/convolution_forward.hpp
+8
-7
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
.../ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
+58
-10
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
...bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
+0
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
f0298581
...
@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
<
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseOp
,
GridwiseOp
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
f0298581
...
@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
...
@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
return
ds_offset
;
return
ds_offset
;
}
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
// alias for kernels without multiple D
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
}
...
@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
...
@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
index_t
BatchStrideB_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
index_t
BatchStrideE_
;
index_t
&
BatchStrideC_
=
BatchStrideE_
;
// alias for kernels without multiple D
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
f0298581
...
@@ -113,7 +113,6 @@ struct PassThrough
...
@@ -113,7 +113,6 @@ struct PassThrough
}
}
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
{
...
@@ -143,9 +142,7 @@ struct PassThrough
...
@@ -143,9 +142,7 @@ struct PassThrough
{
{
y
=
type_convert
<
f8_t
>
(
x
);
y
=
type_convert
<
f8_t
>
(
x
);
}
}
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
bf8_t
,
bf8_t
>
(
bf8_t
&
y
,
const
bf8_t
&
x
)
const
__host__
__device__
void
operator
()
<
bf8_t
,
bf8_t
>
(
bf8_t
&
y
,
const
bf8_t
&
x
)
const
{
{
...
@@ -175,7 +172,6 @@ struct PassThrough
...
@@ -175,7 +172,6 @@ struct PassThrough
{
{
y
=
ck
::
type_convert
<
bf8_t
>
(
x
);
y
=
ck
::
type_convert
<
bf8_t
>
(
x
);
}
}
#endif
};
};
struct
UnaryConvert
struct
UnaryConvert
...
@@ -204,7 +200,6 @@ struct ConvertBF16RTN
...
@@ -204,7 +200,6 @@ struct ConvertBF16RTN
}
}
};
};
#if defined CK_ENABLE_FP8
struct
ConvertF8SR
struct
ConvertF8SR
{
{
// convert to fp8 using stochastic rounding (SR)
// convert to fp8 using stochastic rounding (SR)
...
@@ -221,7 +216,6 @@ struct ConvertF8SR
...
@@ -221,7 +216,6 @@ struct ConvertF8SR
y
=
f8_convert_sr
<
Y
>
(
x
);
y
=
f8_convert_sr
<
Y
>
(
x
);
}
}
};
};
#endif
struct
Scale
struct
Scale
{
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
f0298581
...
@@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
...
@@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
using
ALayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
AsLayout
>>
;
using
ALayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
AsLayout
>>
;
return
MakeAGridDescriptor_M_
N
<
ALayout
,
GemmSpec
>
(
MRaws
[
i
],
KRaws
[
i
],
AsStride
[
i
]);
return
MakeAGridDescriptor_M_
K
<
ALayout
,
GemmSpec
>
(
MRaws
[
i
],
KRaws
[
i
],
AsStride
[
i
]);
},
},
Number
<
NumATensor
>
{});
Number
<
NumATensor
>
{});
}
}
...
@@ -656,7 +656,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
...
@@ -656,7 +656,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
ComputeDataType
,
ComputeDataType
,
// ComputeDataType for A
ComputeDataType
,
// ComputeDataType for B
AccDataType
,
AccDataType
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
f0298581
...
@@ -36,7 +36,7 @@ __global__ void
...
@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
(
kernel_grouped_conv_multiple_d_wmma_cshuffle
(
const
ADataType
*
__restrict__
p_a_grid
,
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
DsPointer
p_ds_grid
,
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// CheckValidity for kernels without multi D
template
<
typename
Block2CTileMap
>
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return
true
;
return
true
;
}
}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
return
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
e_grid_desc_m_n
,
block_2_ctile_map
);
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
f0298581
...
@@ -22,13 +22,19 @@ namespace ck {
...
@@ -22,13 +22,19 @@ namespace ck {
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
Block2CTileMap
>
typename
Block2CTileMap
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_xdlops_v2r4r2_simplified
(
typename
GridwiseGemm
::
Argument
karg
,
kernel_gemm_xdlops_v2r4r2_simplified
(
typename
GridwiseGemm
::
Argument
karg
,
const
Block2CTileMap
&
b2c_map
)
const
Block2CTileMap
&
b2c_map
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -37,10 +43,13 @@ __global__ void
...
@@ -37,10 +43,13 @@ __global__ void
__shared__
uint8_t
p_shared
[
shared_size
];
__shared__
uint8_t
p_shared
[
shared_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
karg
,
static_cast
<
void
*>
(
p_shared
),
b2c_map
);
karg
,
static_cast
<
void
*>
(
p_shared
),
b2c_map
,
a_element_op
,
b_element_op
,
c_element_op
);
#else
#else
ignore
=
karg
;
ignore
=
karg
;
ignore
=
b2c_map
;
ignore
=
b2c_map
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
...
@@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
typename
Block2CTileMap
>
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
Argument
&
karg
,
__device__
static
void
Run
(
const
Argument
&
karg
,
void
*
__restrict__
p_shared_block
,
void
*
__restrict__
p_shared_block
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
,
const
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
const
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
const
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{})
{
{
const
FloatA
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatA
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatB
*
p_b_grid
=
karg
.
p_b_grid
;
const
FloatB
*
p_b_grid
=
karg
.
p_b_grid
;
...
@@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
const
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{};
const
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{};
const
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{};
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
...
@@ -761,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -761,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
ComputeType
,
ComputeType
,
// ComputeType A
ComputeType
,
ComputeType
,
// ComputeType B
FloatAcc
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
f0298581
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto
src_vector_container
=
src_vector_type
{
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
is_src_valid
)};
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
is_src_valid
)};
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
dst_vector_type
op_r_v
;
constexpr
auto
get_elem_op_vec_len
=
[]()
{
if
constexpr
(
is_detected
<
is_pack8_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack8_invocable
)
return
math
::
min
(
8
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack4_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack4_invocable
)
return
math
::
min
(
4
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack2_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack2_invocable
)
return
math
::
min
(
2
,
SrcScalarPerVector
);
}
return
1
;
};
constexpr
index_t
elem_op_vec_len
=
get_elem_op_vec_len
();
using
src_elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
using
dst_elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
idx
)
{
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_
(
op_r_v
.
template
AsType
<
dst_elem_op_vec_t
>()(
idx
),
src_vector_container
.
template
AsType
<
src_elem_op_vec_t
>()[
idx
]);
});
// copy data from src_vector_container into src_thread_scratch_
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_
(
thread_scratch_id
)
src_thread_scratch_tuple_
(
thread_scratch_id
)
.
template
SetAsType
<
src
_vector_t
>(
.
template
SetAsType
<
dst
_vector_t
>(
src_data_idx_seq
,
src_data_idx_seq
,
src_vector_container
.
template
AsType
<
src
_vector_t
>()[
I0
]);
op_r_v
.
template
AsType
<
dst
_vector_t
>()[
I0
]);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
{
...
@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
dst_thread_scratch_
(
idx
)
=
type_convert
<
DstData
>
(
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
});
});
#else
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
// TODO make this logic more generic for more sub-dword datatype
if
constexpr
(
SrcVectorDim
!=
DstVectorDim
&&
if
constexpr
(
SrcVectorDim
!=
DstVectorDim
&&
((
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
((
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
&&
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
)
||
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
)
||
(
is_same
<
int8_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
(
is_same
<
int8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
is_same
<
int8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
4
==
0
&&
DstScalarPerVector
%
4
==
0
)))
SrcScalarPerVector
%
4
==
0
&&
DstScalarPerVector
%
4
==
0
)))
{
{
// each transpose does
// each transpose does
...
@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
data_idx_seq
=
generate_sequence_v2
(
constexpr
auto
data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
data_idx
[
i
]
>
{};
},
Number
<
nDim
>
{});
[
&
](
auto
i
)
{
return
Number
<
data_idx
[
i
]
>
{};
},
Number
<
nDim
>
{});
using
src_vector_t
=
vector_type_maker_t
<
Src
Data
,
SrcScalarPerVector
>
;
using
src_vector_t
=
vector_type_maker_t
<
Dst
Data
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
vector_type_maker_t
<
DstData
,
DstScalarPerVector
>
;
using
dst_vector_t
=
vector_type_maker_t
<
DstData
,
DstScalarPerVector
>
;
// get DstScalarPerVector # of read-only references to src vectors from
// get DstScalarPerVector # of read-only references to src vectors from
...
@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
Number
<
num_dst_vector
>
{});
Number
<
num_dst_vector
>
{});
// do data transpose
// do data transpose
transpose_vectors
<
Src
Data
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
transpose_vectors
<
Dst
Data
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
src_vector_refs
,
dst_vector_refs
);
src_vector_refs
,
dst_vector_refs
);
});
});
}
}
else
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
{
// apply the src elementwise op and convert to DstData under the hood if needed
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
DstData
dst_v
;
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
src_element_op_
(
dst_v
,
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
});
dst_thread_scratch_
(
idx
)
=
dst_v
;
}
});
#endif
#endif
}
}
...
@@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static
constexpr
auto
src_thread_scratch_desc_
=
decltype
(
GetSrcThreadScratchDescriptor
()){};
static
constexpr
auto
src_thread_scratch_desc_
=
decltype
(
GetSrcThreadScratchDescriptor
()){};
static
constexpr
auto
dst_thread_scratch_desc_
=
decltype
(
GetDstThreadScratchDescriptor
()){};
static
constexpr
auto
dst_thread_scratch_desc_
=
decltype
(
GetDstThreadScratchDescriptor
()){};
using
SrcThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
using
SrcThreadScratch
=
SrcData
,
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
SrcScalarPerVector
,
DstData
,
// apply data_convert with SrcThreadScratch
decltype
(
src_thread_scratch_desc_
),
SrcScalarPerVector
,
true
>
;
decltype
(
src_thread_scratch_desc_
),
true
>
;
using
DstThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
using
DstThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
DstData
,
DstData
,
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
View file @
f0298581
...
@@ -132,9 +132,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
...
@@ -132,9 +132,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
Number
<
num
>
{});
Number
<
num
>
{});
}
}
template
<
typename
T
>
using
has_vec_len
=
decltype
(
std
::
declval
<
T
&>
().
vec_len
);
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template
<
typename
SrcBuffers
,
template
<
typename
SrcBuffers
,
...
@@ -159,94 +156,63 @@ struct ThreadwiseTensorSliceTransfer_v7r2
...
@@ -159,94 +156,63 @@ struct ThreadwiseTensorSliceTransfer_v7r2
is_src_valid
);
is_src_valid
);
});
});
if
constexpr
(
is_detected
<
has_vec_len
,
decltype
(
element_op_
)
>::
value
)
constexpr
auto
get_elem_op_vec_len
=
[]()
{
{
if
constexpr
(
is_detected
<
is_pack8_invocable_t
,
decltype
(
element_op_
)
>::
value
)
constexpr
auto
elem_op_vec_len
=
decltype
(
element_op_
)
::
vec_len
;
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack8_invocable
)
static_assert
(
is_same
<
remove_cvref_t
<
decltype
(
elem_op_vec_len
)
>
,
index_t
>::
value
,
return
math
::
min
(
8
,
SrcScalarPerVector
);
"vec_len in element_op_ type is not index_t"
);
}
if
constexpr
(
is_detected
<
is_pack4_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack4_invocable
)
return
math
::
min
(
4
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack2_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack2_invocable
)
return
math
::
min
(
2
,
SrcScalarPerVector
);
}
return
1
;
};
constexpr
index_t
elem_op_vec_len
=
get_elem_op_vec_len
();
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
return
src_vectors
[
iSrc
].
template
AsType
<
elem_op_vec_t
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
return
dst_vectors
(
iDst
).
template
AsType
<
elem_op_vec_t
>()(
i
);
},
Number
<
nDst
>
{});
static_assert
(
elem_op_vec_len
==
1
||
elem_op_vec_len
==
2
||
elem_op_vec_len
==
4
||
elem_op_vec_len
==
8
,
"vec_len in element_op_ must be 1, 2, 4, 8"
);
static_assert
(
SrcScalarPerVector
%
elem_op_vec_len
==
0
,
"vec_len in element_op_ cannot be divided by SrcScalarPerVector!"
);
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
return
src_vectors
[
iSrc
].
template
AsType
<
elem_op_vec_t
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
return
dst_vectors
(
iDst
).
template
AsType
<
elem_op_vec_t
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
}
else
{
// apply pointwise function
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
// pointwise function signature:
// get reference to src data
// element_op_(dst_data_refs[I0],
const
auto
src_data_refs
=
generate_tie
(
// dst_data_refs[I1],
// return type should be lvalue
// ...,
[
&
](
auto
iSrc
)
->
const
auto
&
{
// src_data_refs[I0],
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
// src_data_refs[I1],
// ...)
return
src_vectors
[
iSrc
].
template
AsType
<
SrcData
>()[
i
];
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
},
});
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
return
dst_vectors
(
iDst
).
template
AsType
<
DstData
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
}
dst_vectors_tuple_
(
iAccess
)
=
dst_vectors
;
dst_vectors_tuple_
(
iAccess
)
=
dst_vectors
;
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
f0298581
...
@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
...
@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
}
};
};
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
{
{
...
@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
...
@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
>
{
{
...
@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
...
@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
intrin_mfma_f32_16x16x32bf8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_16x16x32bf8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8bf8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8bf8
>
{
{
...
@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
...
@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
intrin_mfma_f32_16x16x32f8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_16x16x32f8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8f8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8f8
>
{
{
...
@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
...
@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
intrin_mfma_f32_16x16x32bf8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_16x16x32bf8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#endif
template
<
typename
base_type
,
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
MPerXdlops
,
...
@@ -792,7 +784,6 @@ struct MfmaSelector
...
@@ -792,7 +784,6 @@ struct MfmaSelector
}
}
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
{
...
@@ -804,9 +795,7 @@ struct MfmaSelector
...
@@ -804,9 +795,7 @@ struct MfmaSelector
{
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
}
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
{
{
...
@@ -818,9 +807,7 @@ struct MfmaSelector
...
@@ -818,9 +807,7 @@ struct MfmaSelector
{
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
}
}
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
{
{
...
@@ -832,9 +819,7 @@ struct MfmaSelector
...
@@ -832,9 +819,7 @@ struct MfmaSelector
{
{
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
}
}
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
{
{
...
@@ -846,7 +831,6 @@ struct MfmaSelector
...
@@ -846,7 +831,6 @@ struct MfmaSelector
{
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
}
}
#endif
static
constexpr
auto
selected_mfma
=
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
()
>
{};
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
()
>
{};
...
@@ -1051,18 +1035,10 @@ struct XdlopsGemm
...
@@ -1051,18 +1035,10 @@ struct XdlopsGemm
static_assert
(
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
||
#if defined CK_ENABLE_FP8
is_same
<
base_type
,
bf8_t
>::
value
||
||
is_same
<
base_type
,
f8_t
>::
value
(
is_same
<
base_type
,
f8_t
>::
value
&&
is_same
<
additional_type
,
bf8_t
>::
value
)
||
#endif
(
is_same
<
base_type
,
bf8_t
>::
value
&&
is_same
<
additional_type
,
f8_t
>::
value
),
#if defined CK_ENABLE_BF8
||
is_same
<
base_type
,
bf8_t
>::
value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
||
(
is_same
<
base_type
,
f8_t
>::
value
&&
is_same
<
additional_type
,
bf8_t
>::
value
)
||
(
is_same
<
base_type
,
bf8_t
>::
value
&&
is_same
<
additional_type
,
f8_t
>::
value
)
#endif
,
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"
);
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
f0298581
...
@@ -299,584 +299,255 @@ enum struct AmdBufferCoherenceEnum
...
@@ -299,584 +299,255 @@ enum struct AmdBufferCoherenceEnum
GLC_SLC
=
3
,
GLC_SLC
=
3
,
};
};
template
<
typename
T
,
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
index_t
N
,
__device__
typename
vector_type
<
int8_t
,
N
>::
type
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
amd_buffer_load_impl_raw
(
int32x4_t
src_wave_buffer_resource
,
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
index_t
src_wave_addr_offset
)
{
{
static_assert
(
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
"wrong! not implemented"
);
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
if
constexpr
(
N
==
1
)
{
{
// use fp32 load to mimic fp64 load
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
if
constexpr
(
N
==
1
)
src_thread_addr_offset
,
{
src_wave_addr_offset
,
const
float2_t
tmp
=
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double
>
(
tmp
);
}
else
if
constexpr
(
N
==
2
)
{
const
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double2_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
const
float4_t
f32_0
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
const
float4_t
f32_1
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
vector_type
<
double
,
4
>
tmp
;
tmp
.
AsType
<
double2_t
>
()(
Number
<
0
>
{})
=
bit_cast
<
double2_t
>
(
f32_0
);
tmp
.
AsType
<
double2_t
>
()(
Number
<
1
>
{})
=
bit_cast
<
double2_t
>
(
f32_1
);
return
tmp
.
AsType
<
double4_t
>
()(
Number
<
0
>
{});
}
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
else
if
constexpr
(
N
==
2
)
{
{
if
constexpr
(
N
==
1
)
{
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
return
llvm_amdgcn_raw_buffer_load_fp32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
;
tmp
.
AsType
<
float4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
float4_t
>
()(
Number
<
1
>
{})
=
return
bit_cast
<
int8x2_t
>
(
tmp
);
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
float8_t
>
()(
Number
<
0
>
{});
}
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
else
if
constexpr
(
N
==
4
)
{
{
if
constexpr
(
N
==
1
)
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
{
return
llvm_amdgcn_raw_buffer_load_fp16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
// use fp32 load to mimic fp16 load
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
half8_t
>
(
tmp
);
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
bhalf8_t
>
(
tmp
);
return
bit_cast
<
int8x4_t
>
(
tmp
);
}
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
else
if
constexpr
(
N
==
8
)
{
{
if
constexpr
(
N
==
1
)
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
{
return
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
int32_t
,
8
>
tmp
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int32x8_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x2_t
>
(
tmp
);
return
bit_cast
<
int8x8_t
>
(
tmp
);
#endif
}
}
else
if
constexpr
(
N
==
16
)
else
if
constexpr
(
N
==
4
)
{
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x16_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
32
)
{
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
vector_type
<
int32_t
,
8
>
tmp
;
return
bit_cast
<
int8x4_t
>
(
tmp
);
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
tmp0
;
#endif
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
tmp1
;
}
else
if
constexpr
(
N
==
8
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type
<
int8_t
,
8
>
tmp
;
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int8x8_t
>
()(
Number
<
0
>
{});
#else
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x8_t
>
(
tmp
);
return
bit_cast
<
int8x32_t
>
(
tmp
);
#endif
}
}
else
if
constexpr
(
N
==
64
)
else
if
constexpr
(
N
==
16
)
{
{
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
src_thread_addr_offset
,
vector_type
<
int8_t
,
16
>
tmp
;
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp2
=
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
src_wave_addr_offset
+
8
*
sizeof
(
int32_t
),
src_wave_addr_offset
+
4
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp3
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
2
>
{})
=
src_thread_addr_offset
,
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_wave_addr_offset
+
12
*
sizeof
(
int32_t
),
src_thread_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
src_wave_addr_offset
+
8
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
3
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
12
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int8x16_t
>
()(
Number
<
0
>
{});
#else
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x16_t
>
(
tmp
);
vector_type
<
int32_t
,
16
>
tmp
;
#endif
}
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
tmp0
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
tmp1
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
2
>
{})
=
tmp2
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
3
>
{})
=
tmp3
;
return
bit_cast
<
int8x64_t
>
(
tmp
);
}
}
}
}
template
<
typename
T
,
template
<
typename
T
,
index_t
N
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
int32x4_t
dst_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
dst_thread_addr_offset
,
index_t
src_wave_addr_offset
)
index_t
dst_wave_addr_offset
)
{
{
static_assert
(
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
auto
raw_data
=
amd_buffer_load_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
);
return
bit_cast
<
r_t
>
(
raw_data
);
}
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl_raw
(
const
typename
vector_type
<
int8_t
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
"wrong! not implemented"
);
if
constexpr
(
N
==
1
)
{
{
// use fp32 store to mimic fp64 store
llvm_amdgcn_raw_buffer_store_i8
(
src_thread_data
,
if
constexpr
(
N
==
1
)
dst_wave_buffer_resource
,
{
dst_thread_addr_offset
,
llvm_amdgcn_raw_buffer_store_fp32x2
(
bit_cast
<
float2_t
>
(
src_thread_data
),
dst_wave_addr_offset
,
dst_wave_buffer_resource
,
static_cast
<
index_t
>
(
coherence
));
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
else
if
constexpr
(
N
==
2
)
{
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
llvm_amdgcn_raw_buffer_store_fp32
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
}
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
else
if
constexpr
(
N
==
4
)
{
{
if
constexpr
(
N
==
1
)
llvm_amdgcn_raw_buffer_store_i32
(
bit_cast
<
int32_t
>
(
src_thread_data
),
{
dst_wave_buffer_resource
,
llvm_amdgcn_raw_buffer_store_fp16
(
src_thread_data
,
dst_thread_addr_offset
,
dst_wave_buffer_resource
,
dst_wave_addr_offset
,
dst_thread_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
#if 0
vector_type<half_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
else
if
constexpr
(
N
==
8
)
{
{
if
constexpr
(
N
==
1
)
llvm_amdgcn_raw_buffer_store_i32x2
(
bit_cast
<
int32x2_t
>
(
src_thread_data
),
{
dst_wave_buffer_resource
,
llvm_amdgcn_raw_buffer_store_i16
(
src_thread_data
,
dst_thread_addr_offset
,
dst_wave_buffer_resource
,
dst_wave_addr_offset
,
dst_thread_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
bhalf_t
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
bhalf_t
),
static_cast
<
index_t
>
(
coherence
));
}
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
else
if
constexpr
(
N
==
16
)
{
{
if
constexpr
(
N
==
1
)
llvm_amdgcn_raw_buffer_store_i32x4
(
bit_cast
<
int32x4_t
>
(
src_thread_data
),
{
dst_wave_buffer_resource
,
llvm_amdgcn_raw_buffer_store_i32
(
src_thread_data
,
dst_thread_addr_offset
,
dst_wave_buffer_resource
,
dst_wave_addr_offset
,
dst_thread_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i32x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
}
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
else
if
constexpr
(
N
==
32
)
{
{
if
constexpr
(
N
==
1
)
vector_type
<
int32_t
,
8
>
tmp
{
bit_cast
<
int32x8_t
>
(
src_thread_data
)};
{
llvm_amdgcn_raw_buffer_store_i8
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
1
>
{}],
{
dst_wave_buffer_resource
,
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
dst_thread_addr_offset
,
llvm_amdgcn_raw_buffer_store_i8x2
(
src_thread_data
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
dst_wave_buffer_resource
,
static_cast
<
index_t
>
(
coherence
));
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
else
if
constexpr
(
N
==
4
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_i32
(
bit_cast
<
int32_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
else
if
constexpr
(
N
==
8
)
{
llvm_amdgcn_raw_buffer_store_i32x2
(
bit_cast
<
int32x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
16
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
bit_cast
<
int32x4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
}
}
else
if
constexpr
(
N
==
64
)
{
vector_type
<
int32_t
,
16
>
tmp
{
bit_cast
<
int32x16_t
>
(
src_thread_data
)};
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
8
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
12
,
static_cast
<
index_t
>
(
coherence
));
}
}
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
int8_t
,
sizeof
(
T
)
*
N
>::
type
;
amd_buffer_store_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
bit_cast
<
r_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
);
}
}
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
@@ -1127,54 +798,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
...
@@ -1127,54 +798,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
return
bit_cast
<
vector_t
>
(
tmp
);
}
else
{
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#else
#else
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
#endif
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
bit_cast
<
vector_t
>
(
tmp
)
:
vector_t
(
0
);
}
else
{
#endif
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#endif
#endif
}
}
...
@@ -1232,62 +863,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
...
@@ -1232,62 +863,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#else
#else
if
(
dst_thread_element_valid
)
if
(
dst_thread_element_valid
)
{
{
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
}
}
#endif
#endif
}
}
...
...
include/ck/utility/amd_xdlops.hpp
View file @
f0298581
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP
#pragma once
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
...
@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
}
};
};
#if defined CK_ENABLE_FP8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8f8
;
struct
intrin_mfma_f32_32x32x16f8f8
;
...
@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
...
@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
#endif
}
}
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16bf8bf8
;
struct
intrin_mfma_f32_32x32x16bf8bf8
;
...
@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
...
@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
#endif
#endif
}
}
};
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8bf8
;
struct
intrin_mfma_f32_32x32x16f8bf8
;
...
@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
...
@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
#endif
#endif
}
}
};
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16bf8f8
;
struct
intrin_mfma_f32_32x32x16bf8f8
;
...
@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
...
@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
#endif
#endif
}
}
};
};
#endif
}
// namespace ck
}
// namespace ck
#endif
include/ck/utility/data_type.hpp
View file @
f0298581
...
@@ -9,15 +9,9 @@ namespace ck {
...
@@ -9,15 +9,9 @@ namespace ck {
using
bhalf_t
=
ushort
;
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
half_t
=
_Float16
;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
using
f8_t
=
_BitInt
(
8
);
#endif
using
bf8_t
=
unsigned
_BitInt
(
8
);
#if defined CK_ENABLE_FP8
using
f8_t
=
_BitInt
(
8
);
#endif
#if defined CK_ENABLE_BF8
using
bf8_t
=
unsigned
_BitInt
(
8
);
#endif
// vector_type
// vector_type
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
...
@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
};
};
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
scalar_type
<
f8_t
>
struct
scalar_type
<
f8_t
>
{
{
using
type
=
f8_t
;
using
type
=
f8_t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
scalar_type
<
bf8_t
>
struct
scalar_type
<
bf8_t
>
{
{
using
type
=
bf8_t
;
using
type
=
bf8_t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
struct
vector_type
<
T
,
1
>
...
@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
...
@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
// f8
#if defined CK_ENABLE_FP8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
#endif
// bf8
// bf8
#if defined CK_ENABLE_BF8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
...
@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
...
@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
};
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
NumericLimits
<
f8_t
>
struct
NumericLimits
<
f8_t
>
{
{
...
@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
...
@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
NumericLimits
<
bf8_t
>
struct
NumericLimits
<
bf8_t
>
{
{
...
@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
...
@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
};
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
NumericUtils
struct
NumericUtils
...
@@ -1120,22 +1102,18 @@ struct NumericUtils<half_t>
...
@@ -1120,22 +1102,18 @@ struct NumericUtils<half_t>
using
bitwise_type
=
uint16_t
;
using
bitwise_type
=
uint16_t
;
};
};
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
NumericUtils
<
f8_t
>
struct
NumericUtils
<
f8_t
>
{
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
mant
=
3
;
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
NumericUtils
<
bf8_t
>
struct
NumericUtils
<
bf8_t
>
{
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
mant
=
2
;
};
};
#endif
//
}
// namespace ck
}
// namespace ck
include/ck/utility/f8_utils.hpp
View file @
f0298581
...
@@ -6,8 +6,6 @@
...
@@ -6,8 +6,6 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace
ck
{
namespace
ck
{
// fp8 rounding modes
// fp8 rounding modes
...
@@ -244,5 +242,3 @@ __host__ __device__ Y cast_from_f8(X x)
...
@@ -244,5 +242,3 @@ __host__ __device__ Y cast_from_f8(X x)
}
}
}
// namespace ck::utils
}
// namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
include/ck/utility/is_detected.hpp
View file @
f0298581
...
@@ -31,4 +31,13 @@ struct nonesuch
...
@@ -31,4 +31,13 @@ struct nonesuch
template
<
template
<
class
...
>
class
Op
,
class
...
Args
>
template
<
template
<
class
...
>
class
Op
,
class
...
Args
>
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
template
<
typename
T
>
using
is_pack2_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack2_invocable
);
template
<
typename
T
>
using
is_pack4_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack4_invocable
);
template
<
typename
T
>
using
is_pack8_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack8_invocable
);
}
// namespace ck
}
// namespace ck
include/ck/utility/type_convert.hpp
View file @
f0298581
...
@@ -9,8 +9,10 @@
...
@@ -9,8 +9,10 @@
namespace
ck
{
namespace
ck
{
// Convert X to Y
// Convert X to Y, both X and Y are non-const data types.
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<!
(
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
),
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
...
@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x)
...
@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x)
return
static_cast
<
Y
>
(
x
);
return
static_cast
<
Y
>
(
x
);
}
}
// Convert X to Y, either X or Y is a const data type.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
using
NonConstY
=
std
::
remove_const_t
<
Y
>
;
using
NonConstX
=
std
::
remove_const_t
<
X
>
;
return
static_cast
<
Y
>
(
type_convert
<
NonConstY
,
NonConstX
>
(
x
));
}
// convert bfp16 to fp32
// convert bfp16 to fp32
template
<
>
template
<
>
inline
__host__
__device__
constexpr
float
type_convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
inline
__host__
__device__
constexpr
float
type_convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
...
@@ -80,11 +95,23 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
...
@@ -80,11 +95,23 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
}
#if defined CK_ENABLE_FP8
// convert fp32 to fp8
// convert fp32 to fp8
template
<
>
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -92,20 +119,33 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
...
@@ -92,20 +119,33 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
return
utils
::
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
rng
);
#endif
}
}
// convert fp8 to fp32
// convert fp8 to fp32
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f8_t
>
(
f8_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
f8_t
>
(
f8_t
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
fval
=
__builtin_amdgcn_cvt_f32_fp8
(
i32val
,
0
);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return
fval
;
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
x
);
#endif
}
}
// convert fp16 to fp8
// convert fp16 to fp8
template
<
>
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#elif 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -113,22 +153,43 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -113,22 +153,43 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
}
}
// convert fp8 to fp16
// convert fp8 to fp16
template
<
>
template
<
>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_t
>
(
f8_t
x
)
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_t
>
(
f8_t
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#elif 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
}
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
#if defined CK_ENABLE_BF8
// convert fp32 to bf8
// convert fp32 to bf8
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -136,20 +197,33 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
...
@@ -136,20 +197,33 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
return
utils
::
return
utils
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#endif
}
}
// convert bf8 to fp32
// convert bf8 to fp32
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
bf8_t
>
(
bf8_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
bf8_t
>
(
bf8_t
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
fval
=
__builtin_amdgcn_cvt_f32_bf8
(
i32val
,
0
);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return
fval
;
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_t
,
float
,
negative_zero_nan
>
(
x
);
#endif
}
}
// convert fp16 to bf8
// convert fp16 to bf8
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#elif 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -157,16 +231,25 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
...
@@ -157,16 +231,25 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
}
}
// convert bf8 to fp16
// convert bf8 to fp16
template
<
>
template
<
>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
bf8_t
>
(
bf8_t
x
)
inline
__host__
__device__
half_t
type_convert
<
half_t
,
bf8_t
>
(
bf8_t
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#elif 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
}
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
// Declare a template function for bf16 conversion using RTN
// Declare a template function for bf16 conversion using RTN
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
...
@@ -229,58 +312,91 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
...
@@ -229,58 +312,91 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding
// convert fp32 to fp8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
rng
);
#endif
}
}
// convert fp16 to fp8 with stochastic rounding
// convert fp16 to fp8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#elif 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
}
#else
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 with stochastic rounding
// convert fp32 to bf8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#endif
}
}
// convert fp16 to bf8 with stochastic rounding
// convert fp16 to bf8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
half_t
>
(
half_t
x
)
{
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#elif 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
...
@@ -290,7 +406,9 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
...
@@ -290,7 +406,9 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
}
#else
return
f8_convert_sr
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
f0298581
...
@@ -20,12 +20,8 @@ using F16 = ck::half_t;
...
@@ -20,12 +20,8 @@ using F16 = ck::half_t;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
using
I32
=
int32_t
;
#if defined CK_ENABLE_FP8
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
BF8
=
ck
::
bf8_t
;
#endif
#if defined CK_ENABLE_BF8
using
BF8
=
ck
::
bf8_t
;
#endif
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp
View file @
f0298581
...
@@ -240,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
...
@@ -240,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
if
constexpr
(
NumDimSpatial
==
1
&&
is_same_v
<
InLayout
,
NWC
>
&&
is_same_v
<
WeiLayout
,
KXC
>
&&
if
constexpr
(
NumDimSpatial
==
1
&&
is_same_v
<
InLayout
,
NWC
>
&&
is_same_v
<
WeiLayout
,
KXC
>
&&
is_same_v
<
OutLayout
,
NWK
>
)
is_same_v
<
OutLayout
,
NWK
>
)
{
{
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
float
>
)
is_same_v
<
OutDataType
,
float
>
)
{
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
op_ptrs
);
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
...
@@ -267,17 +269,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
...
@@ -267,17 +269,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
}
}
#endif
#endif
}
}
else
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWC
>
&&
is_same_v
<
WeiLayout
,
KYXC
>
&&
is_same_v
<
OutLayout
,
NHWK
>
)
is_same_v
<
WeiLayout
,
KYXC
>
&&
is_same_v
<
OutLayout
,
NHWK
>
)
{
{
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
float
>
)
is_same_v
<
OutDataType
,
float
>
)
{
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
#ifdef DL_KERNELS
}
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
#endif
#endif
#if defined(DL_KERNELS) && defined(CK_ENABLE_FP32)
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
float
>
)
{
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
...
@@ -306,14 +314,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
...
@@ -306,14 +314,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
}
}
#endif
#endif
}
}
else
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWC
>
&&
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWC
>
&&
is_same_v
<
WeiLayout
,
KZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWK
>
)
is_same_v
<
WeiLayout
,
KZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWK
>
)
{
{
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
float
>
)
is_same_v
<
OutDataType
,
float
>
)
{
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
op_ptrs
);
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp
View file @
f0298581
...
@@ -98,30 +98,31 @@ struct DeviceOperationInstanceFactory<
...
@@ -98,30 +98,31 @@ struct DeviceOperationInstanceFactory<
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWC
>
&&
is_same_v
<
WeiLayout
,
KYXC
>
&&
is_same_v
<
OutLayout
,
NHWK
>
)
is_same_v
<
WeiLayout
,
KYXC
>
&&
is_same_v
<
OutLayout
,
NHWK
>
)
{
{
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
float
>
)
is_same_v
<
OutDataType
,
float
>
)
{
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
{
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
}
}
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
}
}
#endif
#endif
#ifdef CK_ENABLE_INT8
#ifdef CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
is_same_v
<
OutDataType
,
int8_t
>
)
{
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
}
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
View file @
f0298581
...
@@ -98,6 +98,26 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
...
@@ -98,6 +98,26 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
);
#endif
#endif
template
<
typename
ADataType
,
template
<
typename
ADataType
,
...
@@ -105,7 +125,8 @@ template <typename ADataType,
...
@@ -105,7 +125,8 @@ template <typename ADataType,
typename
CDataType
,
typename
CDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
,
typename
ComputeType
>
struct
DeviceOperationInstanceFactory
<
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmSplitK
<
ALayout
,
ck
::
tensor_operation
::
device
::
DeviceGemmSplitK
<
ALayout
,
BLayout
,
BLayout
,
...
@@ -115,7 +136,8 @@ struct DeviceOperationInstanceFactory<
...
@@ -115,7 +136,8 @@ struct DeviceOperationInstanceFactory<
CDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ComputeType
>>
{
{
using
DeviceOp
=
DeviceGemmSplitK
<
ALayout
,
using
DeviceOp
=
DeviceGemmSplitK
<
ALayout
,
BLayout
,
BLayout
,
...
@@ -125,14 +147,15 @@ struct DeviceOperationInstanceFactory<
...
@@ -125,14 +147,15 @@ struct DeviceOperationInstanceFactory<
CDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ComputeType
>
;
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_FP32
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
CDataType
,
float
>
)
is_same_v
<
CDataType
,
float
>
&&
is_same_v
<
ComputeType
,
float
>
)
{
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
...
@@ -157,8 +180,8 @@ struct DeviceOperationInstanceFactory<
...
@@ -157,8 +180,8 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
C
Data
Type
,
half_t
>
)
is_same_v
<
CDataType
,
half_t
>
&&
is_same_v
<
C
ompute
Type
,
half_t
>
)
{
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
...
@@ -183,8 +206,8 @@ struct DeviceOperationInstanceFactory<
...
@@ -183,8 +206,8 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
else
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
C
Data
Type
,
half_t
>
)
is_same_v
<
CDataType
,
half_t
>
&&
is_same_v
<
C
ompute
Type
,
half_t
>
)
{
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
...
@@ -207,8 +230,8 @@ struct DeviceOperationInstanceFactory<
...
@@ -207,8 +230,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
C
Data
Type
,
half_t
>
)
is_same_v
<
CDataType
,
half_t
>
&&
is_same_v
<
C
ompute
Type
,
half_t
>
)
{
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
...
@@ -231,6 +254,31 @@ struct DeviceOperationInstanceFactory<
...
@@ -231,6 +254,31 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
&&
is_same_v
<
ComputeType
,
f8_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances
(
op_ptrs
);
}
}
#endif
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
View file @
f0298581
...
@@ -6,8 +6,6 @@
...
@@ -6,8 +6,6 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment