Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6d9a07d7
Commit
6d9a07d7
authored
Feb 29, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
b30d416c
a776978c
Changes
193
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1729 additions
and
395 deletions
+1729
-395
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+46
-15
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+10
-8
include/ck/wrapper/operations/copy.hpp
include/ck/wrapper/operations/copy.hpp
+29
-39
include/ck/wrapper/operations/gemm.hpp
include/ck/wrapper/operations/gemm.hpp
+75
-23
include/ck/wrapper/tensor.hpp
include/ck/wrapper/tensor.hpp
+2
-2
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp
+28
-19
include/ck/wrapper/utils/kernel_utils.hpp
include/ck/wrapper/utils/kernel_utils.hpp
+14
-0
include/ck/wrapper/utils/layout_utils.hpp
include/ck/wrapper/utils/layout_utils.hpp
+97
-8
include/ck/wrapper/utils/tensor_partition.hpp
include/ck/wrapper/utils/tensor_partition.hpp
+182
-108
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...eference_tensor_operation/cpu/reference_conv_bwd_data.hpp
+186
-69
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
.../ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
+6
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp
...ta/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp
+132
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp
...onv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp
+131
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp
...stance/gpu/grouped_convolution_backward_data_bilinear.hpp
+150
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp
...ion_instance/gpu/grouped_convolution_forward_bilinear.hpp
+177
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
...y/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
+48
-1
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp
...k/library/tensor_operation_instance/gpu/permute_scale.hpp
+186
-9
library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp
...ance/gpu/permute_scale/device_permute_scale_instances.hpp
+42
-56
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
.../tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
+41
-38
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp
..._gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp
+147
-0
No files found.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
6d9a07d7
...
...
@@ -8,6 +8,8 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
...
...
@@ -1156,27 +1158,56 @@ struct ThreadwiseTensorSliceTransfer_v4
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
// apply type convert
src_tmp_vector
.
template
AsType
<
SrcData
>()(
i
)
=
src_buf
[
Number
<
src_offset
>
{}];
});
}
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
dst_tmp_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
src_tmp_vector
.
template
AsType
<
SrcData
>()[
i
]);
});
if
constexpr
(
is_same
<
remove_cvref_t
<
SrcData
>
,
f8_t
>::
value
&&
is_same
<
remove_cvref_t
<
DstData
>
,
half_t
>::
value
&&
SrcScalarPerVector
%
2
==
0
)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
constexpr
index_t
pack_size
=
2
;
using
dst_v_t
=
typename
vector_type_maker_t
<
DstData
,
pack_size
>::
type
;
using
src_v_t
=
typename
vector_type_maker_t
<
SrcData
,
pack_size
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
pack_size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
tensor_operation
::
element_wise
::
PassThroughPack2
{}(
dst_tmp_vector
.
template
AsType
<
dst_v_t
>()(
i
),
src_tmp_vector
.
template
AsType
<
src_v_t
>()[
i
]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
else
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
dst_tmp_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
src_tmp_vector
.
template
AsType
<
SrcData
>()[
i
]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
});
}
...
...
include/ck/utility/type_convert.hpp
View file @
6d9a07d7
...
...
@@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
float
max_fp8
=
240.0
f
;
if
(
!
std
::
isinf
(
x
))
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
union
{
float
fval
;
...
...
@@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
...
...
@@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
union
...
...
@@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
...
...
@@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
float
>
(
float
x
)
{
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
if
(
!
std
::
isinf
(
x
))
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
#if defined(__gfx94__)
union
{
float
fval
;
...
...
include/ck/wrapper/operations/copy.hpp
View file @
6d9a07d7
...
...
@@ -61,12 +61,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
decltype
(
dim_access_order
),
VectorDim
,
ScalarPerVector
,
Sequence
<
fals
e
>
,
Sequence
<
fals
e
>>
{
in_grid_desc
,
make_tuple
(
src_tensor
.
GetMultiIdxOffsets
()),
out_grid_desc
,
make_tuple
(
dst_tensor
.
GetMultiIdxOffsets
()),
tensor_operation
::
element_wise
::
PassThrough
{}};
Sequence
<
tru
e
>
,
Sequence
<
tru
e
>>
{
in_grid_desc
,
make_tuple
(
src_tensor
.
GetMultiIdxOffsets
()),
out_grid_desc
,
make_tuple
(
dst_tensor
.
GetMultiIdxOffsets
()),
tensor_operation
::
element_wise
::
PassThrough
{}};
transfer
.
Run
(
tie
(
in_grid_desc
),
tie
(
src_tensor
.
GetBuffer
()),
...
...
@@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
else
if
constexpr
(
SrcTensorType
::
IsDynamicBuffer
&&
!
DstTensorType
::
IsDynamicBuffer
)
{
// Perform copy from DynamicBuffer to StaticBuffer
const
auto
src_
dst_slice_origin
=
const
auto
dst_slice_origin
_idxs
=
generate_tuple
([
&
](
auto
)
{
return
I0
;
},
Number
<
num_dims
>
{});
constexpr
auto
src_vector_tensor_lengths
=
generate_sequence_v2
(
[
&
](
auto
I
)
{
if
constexpr
(
I
==
VectorDim
)
{
return
Number
<
ScalarPerVector
>
{};
}
else
{
return
I1
;
}
},
Number
<
num_dims
>
{});
auto
transfer
=
ThreadwiseTensorSliceTransfer_v4r1
<
typename
SrcTensorType
::
TensorElementType
,
typename
DstTensorType
::
TensorElementType
,
remove_cvref_t
<
decltype
(
in_grid_desc
)
>
,
remove_cvref_t
<
decltype
(
out_grid_desc
)
>
,
decltype
(
thread_slice_lengths
),
decltype
(
dim_access_order
),
decltype
(
src_vector_tensor_lengths
),
decltype
(
dim_access_order
)
>
{
src_tensor
.
GetMultiIdxOffsets
()};
auto
transfer
=
ThreadwiseTensorSliceTransfer_v2
<
std
::
remove_const_t
<
typename
SrcTensorType
::
TensorElementType
>
,
std
::
remove_const_t
<
typename
DstTensorType
::
TensorElementType
>
,
remove_cvref_t
<
decltype
(
in_grid_desc
)
>
,
remove_cvref_t
<
decltype
(
out_grid_desc
)
>
,
decltype
(
thread_slice_lengths
),
decltype
(
dim_access_order
),
VectorDim
,
ScalarPerVector
,
I1
,
false
,
false
>
{
in_grid_desc
,
src_tensor
.
GetMultiIdxOffsets
()};
transfer
.
Run
(
in_grid_desc
,
src_dst_slice_origin
,
src_tensor
.
GetBuffer
(),
out_grid_desc
,
src_
dst_slice_origin
,
dst_slice_origin
_idxs
,
dst_tensor
.
GetBuffer
());
}
else
...
...
@@ -183,10 +171,12 @@ template <typename DimAccessOrderTuple,
index_t
ScalarPerVector
,
typename
SrcTensorType
,
typename
DstTensorType
,
typename
ThreadLayoutTuple
>
__device__
void
blockwise_copy
(
const
SrcTensorType
&
src_tensor
,
DstTensorType
&
dst_tensor
,
[[
maybe_unused
]]
ThreadLayoutTuple
&
thread_layout
)
typename
ThreadShape
,
typename
ThreadUnrolledDesc
>
__device__
void
blockwise_copy
(
const
SrcTensorType
&
src_tensor
,
DstTensorType
&
dst_tensor
,
[[
maybe_unused
]]
const
Layout
<
ThreadShape
,
ThreadUnrolledDesc
>&
thread_layout
)
{
static_assert
(
SrcTensorType
::
IsDynamicBuffer
&&
DstTensorType
::
IsDynamicBuffer
);
static_assert
(
is_detected
<
is_tuple
,
DimAccessOrderTuple
>::
value
);
...
...
@@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor,
constexpr
auto
tile_lengths_seq
=
generate_sequence_v2
([](
auto
I
)
{
return
size
(
SrcShapeType
{}.
At
(
I
));
},
Number
<
num_dims
>
{});
constexpr
auto
thread_layout_seq
=
generate_sequence_v2
(
[](
auto
I
)
{
return
size
(
Thread
LayoutTuple
{}.
At
(
I
)
);
},
Number
<
num_dims
>
{});
constexpr
auto
thread_layout_seq
=
generate_sequence_v2
(
[](
auto
I
)
{
return
size
<
I
>
(
Thread
Shape
{}
);
},
Number
<
num_dims
>
{});
constexpr
auto
dim_access_order
=
generate_sequence_v2
(
[](
auto
I
)
{
return
DimAccessOrderTuple
{}.
At
(
I
);
},
Number
<
num_dims
>
{});
using
ThisThreadBlock
=
ThisThreadBlock
<
size
(
Thread
LayoutTupl
e
{})
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
size
(
Thread
Shap
e
{})
>
;
// Perform copy between DynamicBuffers
auto
transfer
=
ThreadGroupTensorSliceTransfer_v7
<
...
...
include/ck/wrapper/operations/gemm.hpp
View file @
6d9a07d7
...
...
@@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor()
/**
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
* data layout must be (NPerBlock, KPerBlock).
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or
* (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock)
* or (K0PerBlock, NPerBlock, K1).
*
* \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
...
...
@@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor()
* \tparam BlockSize Tensor to pad.
* \tparam GemmTraits Traits of gemm xdl operation.
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
* (MPerBlock, KPerBlock) layout.
* (MPerBlock, KPerBlock)
or (K0PerBlock, MPerBlock, K1)
layout.
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
* (NPerBlock, KPerBlock) layout.
* (NPerBlock, KPerBlock)
or (K0PerBlock, NPerBlock, K1)
layout.
* \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
*/
template
<
typename
DataType
,
...
...
@@ -86,6 +87,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
const
BTensorType
&
b_local_tile_tensor
,
CTensorType
&
c_reg_tensor
)
{
constexpr
auto
I3
=
Number
<
3
>
{};
static_assert
(
ATensorType
::
TensorBufferAddressSpace
==
MemoryTypeEnum
::
Lds
);
static_assert
(
BTensorType
::
TensorBufferAddressSpace
==
MemoryTypeEnum
::
Lds
);
static_assert
(
CTensorType
::
TensorBufferAddressSpace
==
MemoryTypeEnum
::
Vgpr
);
...
...
@@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
using
ATileLayout
=
remove_cvref_t
<
decltype
(
layout
(
a_local_tile_tensor
))
>
;
using
BTileLayout
=
remove_cvref_t
<
decltype
(
layout
(
b_local_tile_tensor
))
>
;
static_assert
(
typename
ATileLayout
::
LayoutShape
{}.
Size
()
==
typename
BTileLayout
::
LayoutShape
{}.
Size
());
constexpr
bool
is_3d_desc
=
typename
ATileLayout
::
LayoutShape
{}.
Size
()
==
I3
;
using
ABlockDesc_K0_M_K1_Type
=
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
ATileLayout
>
());
conditional_t
<
is_3d_desc
,
typename
ATileLayout
::
LayoutUnrolledDescriptorType
,
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
ATileLayout
>
())
>
;
using
BBlockDesc_K0_N_K1_Type
=
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
BTileLayout
>
());
conditional_t
<
is_3d_desc
,
typename
BTileLayout
::
LayoutUnrolledDescriptorType
,
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
BTileLayout
>
())
>
;
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
DataType
,
...
...
@@ -168,14 +179,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
static_assert
(
typename
ATileLayout
::
LayoutShape
{}.
Size
()
==
typename
BTileLayout
::
LayoutShape
{}.
Size
());
constexpr
bool
is_integer
=
is_same_v
<
DataType
,
int8_t
>
||
is_same_v
<
DataType
,
int16_t
>
||
is_same_v
<
DataType
,
int32_t
>
;
using
GemmAccDataType
=
std
::
conditional_t
<
is_integer
,
int32_t
,
float
>
;
constexpr
bool
is_3d_desc
=
typename
ATileLayout
::
LayoutShape
{}.
Size
()
==
I3
;
using
ABlockDesc_K0_M_K1_Type
=
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
ATileLayout
>
());
conditional_t
<
is_3d_desc
,
typename
ATileLayout
::
LayoutUnrolledDescriptorType
,
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
ATileLayout
>
())
>
;
using
BBlockDesc_K0_N_K1_Type
=
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
BTileLayout
>
());
conditional_t
<
is_3d_desc
,
typename
BTileLayout
::
LayoutUnrolledDescriptorType
,
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
BTileLayout
>
())
>
;
using
BlockwiseGemmXdlops
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
...
...
@@ -233,19 +252,45 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
const
auto
partition_desc
=
BlockwiseGemmXdlops
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
layout
(
c_local_tile_tensor
).
GetUnrolledDescriptor
());
const
auto
lower_upper_dims
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
8
>
{});
auto
sliced_desc
=
transform_tensor_descriptor
(
partition_desc
,
make_tuple
(
make_slice_transform
(
partition_shape
.
At
(
Number
<
0
>
{}),
m_thread_data_on_grid_idx
[
I0
],
partition_shape
.
At
(
Number
<
0
>
{})
+
m_thread_data_on_grid_idx
[
I0
]),
make_slice_transform
(
partition_shape
.
At
(
Number
<
1
>
{}),
n_thread_data_on_grid_idx
[
I0
],
partition_shape
.
At
(
Number
<
1
>
{})
+
n_thread_data_on_grid_idx
[
I0
]),
make_slice_transform
(
partition_shape
.
At
(
Number
<
2
>
{}),
m_thread_data_on_grid_idx
[
I1
],
partition_shape
.
At
(
Number
<
2
>
{})
+
m_thread_data_on_grid_idx
[
I1
]),
make_slice_transform
(
partition_shape
.
At
(
Number
<
3
>
{}),
n_thread_data_on_grid_idx
[
I1
],
partition_shape
.
At
(
Number
<
3
>
{})
+
n_thread_data_on_grid_idx
[
I1
]),
make_slice_transform
(
partition_shape
.
At
(
Number
<
4
>
{}),
m_thread_data_on_grid_idx
[
I2
],
partition_shape
.
At
(
Number
<
4
>
{})
+
m_thread_data_on_grid_idx
[
I2
]),
make_slice_transform
(
partition_shape
.
At
(
Number
<
5
>
{}),
m_thread_data_on_grid_idx
[
I3
],
partition_shape
.
At
(
Number
<
5
>
{})
+
m_thread_data_on_grid_idx
[
I3
]),
make_slice_transform
(
partition_shape
.
At
(
Number
<
6
>
{}),
m_thread_data_on_grid_idx
[
I4
],
partition_shape
.
At
(
Number
<
6
>
{})
+
m_thread_data_on_grid_idx
[
I4
]),
make_slice_transform
(
partition_shape
.
At
(
Number
<
7
>
{}),
n_thread_data_on_grid_idx
[
I2
],
partition_shape
.
At
(
Number
<
7
>
{})
+
n_thread_data_on_grid_idx
[
I2
])),
lower_upper_dims
,
lower_upper_dims
);
const
auto
partition_layout
=
Layout
<
remove_reference_t
<
decltype
(
partition_shape
)
>
,
decltype
(
partition
_desc
)
>
(
partition_shape
,
partition
_desc
);
Layout
<
remove_reference_t
<
decltype
(
partition_shape
)
>
,
decltype
(
sliced
_desc
)
>
(
partition_shape
,
sliced
_desc
);
auto
partition_tensor
=
make_tensor
<
CTensorType
::
TensorBufferAddressSpace
>
(
c_local_tile_tensor
.
GetPointer
(),
partition_layout
);
partition_tensor
.
SetMultiIdxOffset
(
make_multi_index
(
m_thread_data_on_grid_idx
[
I0
],
n_thread_data_on_grid_idx
[
I0
],
m_thread_data_on_grid_idx
[
I1
],
n_thread_data_on_grid_idx
[
I1
],
m_thread_data_on_grid_idx
[
I2
],
m_thread_data_on_grid_idx
[
I3
],
m_thread_data_on_grid_idx
[
I4
],
n_thread_data_on_grid_idx
[
I2
]));
return
partition_tensor
;
}
...
...
@@ -292,14 +337,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
static_assert
(
typename
ATileLayout
::
LayoutShape
{}.
Size
()
==
typename
BTileLayout
::
LayoutShape
{}.
Size
());
constexpr
bool
is_integer
=
is_same_v
<
DataType
,
int8_t
>
||
is_same_v
<
DataType
,
int16_t
>
||
is_same_v
<
DataType
,
int32_t
>
;
using
GemmAccDataType
=
std
::
conditional_t
<
is_integer
,
int32_t
,
float
>
;
constexpr
bool
is_3d_desc
=
typename
ATileLayout
::
LayoutShape
{}.
Size
()
==
I3
;
using
ABlockDesc_K0_M_K1_Type
=
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
ATileLayout
>
());
conditional_t
<
is_3d_desc
,
typename
ATileLayout
::
LayoutUnrolledDescriptorType
,
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
ATileLayout
>
())
>
;
using
BBlockDesc_K0_N_K1_Type
=
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
BTileLayout
>
());
conditional_t
<
is_3d_desc
,
typename
BTileLayout
::
LayoutUnrolledDescriptorType
,
decltype
(
detail
::
GetBlockDescriptor
<
GemmTraits
::
K1
,
BTileLayout
>
())
>
;
using
BlockwiseGemmXdlops
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
...
...
@@ -326,9 +379,8 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
const
auto
vgpr_layout
=
Layout
<
remove_reference_t
<
decltype
(
vgpr_shape
)
>
,
decltype
(
vgpr_desc
)
>
(
vgpr_shape
,
vgpr_desc
);
// Get vector type for Vgpr
using
BlockwiseGemmCThreadBufferType
=
remove_reference_t
<
decltype
(
BlockwiseGemmXdlops
{}.
GetCThreadBuffer
())
>
;
using
VgprVectorType
=
typename
BlockwiseGemmCThreadBufferType
::
V
;
constexpr
index_t
ScalarPerVector
=
BlockwiseGemmXdlops
::
xdlops_gemm
.
GetRegSizePerXdlops
();
using
VgprVectorType
=
typename
vector_type
<
GemmAccDataType
,
ScalarPerVector
>::
type
;
return
ck
::
wrapper
::
make_register_tensor
<
ck
::
wrapper
::
MemoryTypeEnum
::
Vgpr
,
VgprVectorType
>
(
vgpr_layout
);
}
...
...
include/ck/wrapper/tensor.hpp
View file @
6d9a07d7
...
...
@@ -172,10 +172,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>&
}
}
template
<
typename
...
Ts
,
typename
Shape
,
typename
Flatten
Descriptor
>
template
<
typename
...
Ts
,
typename
Shape
,
typename
Unrolled
Descriptor
>
__host__
__device__
constexpr
auto
GenerateSlicedDescriptor
(
const
Tuple
<
Ts
...
>&
idx
,
const
Shape
&
shape
,
const
Flatten
Descriptor
&
flatten_desc
)
const
Unrolled
Descriptor
&
flatten_desc
)
{
constexpr
auto
old_shape_dims
=
decltype
(
UnrollNestedTuple
(
shape
))
::
Size
();
...
...
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp
View file @
6d9a07d7
...
...
@@ -20,48 +20,57 @@ namespace wrapper {
* \tparam K1Value The number of K-dim elements that are packed together as
* a separate logical dimension. Usually aligns with vector load size.
*/
template
<
index_t
MPerXDLValue
,
index_t
NPerXDLValue
,
index_t
MXdlPerWaveValue
,
index_t
NXdlPerWaveValue
,
index_t
K1Value
>
template
<
typename
MPerXDLValue
,
typename
NPerXDLValue
,
typename
MXdlPerWaveValue
,
typename
NXdlPerWaveValue
,
typename
K1Value
>
struct
BlockwisGemmXdlTraits
{
static
constexpr
index_t
MPerXDL
=
MPerXDLValue
;
static
constexpr
index_t
NPerXDL
=
NPerXDLValue
;
static
constexpr
index_t
MXdlPerWave
=
MXdlPerWaveValue
;
static
constexpr
index_t
NXdlPerWave
=
NXdlPerWaveValue
;
static
constexpr
index_t
K1
=
K1Value
;
static
constexpr
auto
MPerXDL
=
MPerXDLValue
{}
;
static
constexpr
auto
NPerXDL
=
NPerXDLValue
{}
;
static
constexpr
auto
MXdlPerWave
=
MXdlPerWaveValue
{}
;
static
constexpr
auto
NXdlPerWave
=
NXdlPerWaveValue
{}
;
static
constexpr
auto
K1
=
K1Value
{}
;
};
// K1 = 4
struct
BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
:
BlockwisGemmXdlTraits
<
32
,
32
,
4
,
2
,
4
>
struct
BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
:
BlockwisGemmXdlTraits
<
Number
<
32
>
,
Number
<
32
>
,
Number
<
4
>
,
Number
<
2
>
,
Number
<
4
>>
{
};
struct
BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
:
BlockwisGemmXdlTraits
<
32
,
32
,
2
,
4
,
4
>
struct
BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
:
BlockwisGemmXdlTraits
<
Number
<
32
>
,
Number
<
32
>
,
Number
<
2
>
,
Number
<
4
>
,
Number
<
4
>>
{
};
struct
BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
:
BlockwisGemmXdlTraits
<
32
,
32
,
2
,
2
,
4
>
struct
BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
:
BlockwisGemmXdlTraits
<
Number
<
32
>
,
Number
<
32
>
,
Number
<
2
>
,
Number
<
2
>
,
Number
<
4
>>
{
};
// K1 = 8
struct
BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
:
BlockwisGemmXdlTraits
<
32
,
32
,
4
,
2
,
8
>
struct
BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
:
BlockwisGemmXdlTraits
<
Number
<
32
>
,
Number
<
32
>
,
Number
<
4
>
,
Number
<
2
>
,
Number
<
8
>>
{
};
struct
BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
:
BlockwisGemmXdlTraits
<
32
,
32
,
2
,
4
,
8
>
struct
BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
:
BlockwisGemmXdlTraits
<
Number
<
32
>
,
Number
<
32
>
,
Number
<
2
>
,
Number
<
4
>
,
Number
<
8
>>
{
};
struct
BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
:
BlockwisGemmXdlTraits
<
32
,
32
,
2
,
2
,
8
>
struct
BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
:
BlockwisGemmXdlTraits
<
Number
<
32
>
,
Number
<
32
>
,
Number
<
2
>
,
Number
<
2
>
,
Number
<
8
>>
{
};
// K1 = 16
struct
BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
:
BlockwisGemmXdlTraits
<
32
,
32
,
4
,
2
,
16
>
struct
BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
:
BlockwisGemmXdlTraits
<
Number
<
32
>
,
Number
<
32
>
,
Number
<
4
>
,
Number
<
2
>
,
Number
<
16
>>
{
};
struct
BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
:
BlockwisGemmXdlTraits
<
32
,
32
,
2
,
4
,
16
>
struct
BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
:
BlockwisGemmXdlTraits
<
Number
<
32
>
,
Number
<
32
>
,
Number
<
2
>
,
Number
<
4
>
,
Number
<
16
>>
{
};
struct
BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
:
BlockwisGemmXdlTraits
<
32
,
32
,
2
,
2
,
16
>
struct
BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
:
BlockwisGemmXdlTraits
<
Number
<
32
>
,
Number
<
32
>
,
Number
<
2
>
,
Number
<
2
>
,
Number
<
16
>>
{
};
...
...
include/ck/wrapper/utils/kernel_utils.hpp
0 → 100644
View file @
6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace
ck
{
namespace
wrapper
{
#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
}
// namespace wrapper
}
// namespace ck
include/ck/wrapper/utils/layout_utils.hpp
View file @
6d9a07d7
...
...
@@ -15,6 +15,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace
ck
{
namespace
wrapper
{
...
...
@@ -29,6 +30,7 @@ template <typename T>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
namespace
{
namespace
detail
{
/**
* \brief Generate packed (column-major) strides if not passed
*
...
...
@@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
return
make_naive_tensor_descriptor
(
unrolled_shape
,
unrolled_strides
);
}
}
}
// namespace detail
}
// namespace
/// @endcond
...
...
@@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
template
<
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
auto
make_layout
(
const
Shape
&
shape
,
const
Strides
&
strides
)
{
using
UnrolledDescriptorType
=
decltype
(
MakeUnrolledDescriptor
(
Shape
{},
Strides
{}));
return
Layout
<
Shape
,
UnrolledDescriptorType
>
(
shape
,
MakeUnrolledDescriptor
(
shape
,
strides
));
using
UnrolledDescriptorType
=
decltype
(
detail
::
MakeUnrolledDescriptor
(
Shape
{},
Strides
{}));
return
Layout
<
Shape
,
UnrolledDescriptorType
>
(
shape
,
detail
::
MakeUnrolledDescriptor
(
shape
,
strides
));
}
/**
...
...
@@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template
<
typename
Shape
>
__host__
__device__
constexpr
auto
make_layout
(
const
Shape
&
shape
)
{
using
UnrolledDescriptorType
=
decltype
(
MakeUnrolledDescriptor
(
Shape
{},
Tuple
<>
{}));
return
Layout
<
Shape
,
UnrolledDescriptorType
>
(
shape
,
MakeUnrolledDescriptor
(
shape
,
Tuple
<>
{}));
using
UnrolledDescriptorType
=
decltype
(
detail
::
MakeUnrolledDescriptor
(
Shape
{},
Tuple
<>
{}));
return
Layout
<
Shape
,
UnrolledDescriptorType
>
(
shape
,
detail
::
MakeUnrolledDescriptor
(
shape
,
Tuple
<>
{}));
}
// Layout helpers
// get
/**
* \private
* \brief Get dim.
...
...
@@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
* \param layout Layout to create sub layout.
* \return Requsted sub layout.
*/
template
<
index_t
idx
,
typename
Shape
,
typename
Flatten
Desc
>
__host__
__device__
constexpr
auto
get
(
const
Layout
<
Shape
,
Flatten
Desc
>&
layout
)
template
<
index_t
idx
,
typename
Shape
,
typename
Unrolled
Desc
>
__host__
__device__
constexpr
auto
get
(
const
Layout
<
Shape
,
Unrolled
Desc
>&
layout
)
{
const
auto
&
shape
=
layout
.
GetShape
();
const
auto
new_shape
=
get
<
idx
>
(
shape
);
...
...
@@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
return
layout
.
GetShape
();
}
// pad
/**
* \brief Pad layout shapes to be adjusted to tile lengths.
*
*
* \param layout Layout to pad.
* \param tile_lengths Tile lengths to align layout shape.
* \return Padded layout.
*/
template
<
typename
Shape
,
typename
UnrolledDesc
,
typename
TileLengths
>
__host__
__device__
constexpr
auto
pad
(
const
Layout
<
Shape
,
UnrolledDesc
>&
layout
,
const
TileLengths
&
tile_lengths
)
{
auto
&
unrolled_desc
=
layout
.
GetUnrolledDescriptor
();
// Generate sequence with ones to mark that all dims will be padded
constexpr
auto
do_pads_seq
=
generate_sequence_v2
([](
auto
)
{
return
Number
<
1
>
{};
},
Number
<
Shape
::
Size
()
>
{});
// Create descriptor with padding
auto
padded_desc
=
tensor_operation
::
device
::
PadTensorDescriptor
(
unrolled_desc
,
tile_lengths
,
do_pads_seq
);
// Generate padded shape
const
auto
padded_shape
=
generate_tuple
(
[
&
](
auto
i
)
{
return
padded_desc
.
GetLength
(
Number
<
i
>
{});
},
Number
<
TileLengths
::
Size
()
>
{});
// Create layout
return
Layout
<
decltype
(
padded_shape
),
decltype
(
padded_desc
)
>
(
padded_shape
,
padded_desc
);
}
// unmerge
/**
* \brief Unmerge selected dim in layout.
*
* \tparam Idx Index to dimension being unmerged.
* \param layout Layout to pad.
* \param new_lengths Dimensions into which the indicated dimension will be divided.
* \param new_indexes Indexes to shuffle dims. Dims for unmerged dim should be nested.
* \return Unmerged layout.
*/
template
<
index_t
Idx
,
typename
Shape
,
typename
UnrolledDesc
,
typename
NewLengths
,
typename
NewIdxs
>
__host__
__device__
constexpr
auto
unmerge
(
const
Layout
<
Shape
,
UnrolledDesc
>&
layout
,
const
NewLengths
&
new_lengths
,
[[
maybe_unused
]]
const
NewIdxs
&
new_indexes
)
{
const
auto
&
layout_shape
=
shape
(
layout
);
auto
&
unrolled_desc
=
layout
.
GetUnrolledDescriptor
();
constexpr
auto
dims
=
Shape
::
Size
();
// Generate transforms
const
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
Idx
)
{
return
make_unmerge_transform
(
new_lengths
);
}
else
{
return
make_pass_through_transform
(
layout_shape
.
At
(
i
));
}
},
Number
<
dims
>
{});
constexpr
auto
lower_dims
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
dims
>
{});
constexpr
auto
upper_dims
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
i
.
value
,
NewIdxs
>>::
value
)
{
constexpr
auto
idxs_tuple
=
tuple_element_t
<
i
.
value
,
NewIdxs
>
{};
return
to_sequence
(
idxs_tuple
);
}
else
{
constexpr
index_t
index
=
tuple_element_t
<
i
.
value
,
NewIdxs
>
{};
return
Sequence
<
index
>
{};
}
},
Number
<
dims
>
{});
const
auto
unmerged_desc
=
transform_tensor_descriptor
(
unrolled_desc
,
transforms
,
lower_dims
,
upper_dims
);
const
auto
unmerged_shape
=
generate_tuple
([
&
](
auto
i
)
{
return
unmerged_desc
.
GetLength
(
Number
<
i
>
{});
},
Number
<
decltype
(
unmerged_desc
)
::
GetNumOfVisibleDimension
()
>
{});
// Create layout
return
Layout
<
decltype
(
unmerged_shape
),
decltype
(
unmerged_desc
)
>
(
unmerged_shape
,
unmerged_desc
);
}
}
// namespace wrapper
}
// namespace ck
include/ck/wrapper/utils/tensor_partition.hpp
View file @
6d9a07d7
...
...
@@ -6,7 +6,6 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
...
...
@@ -44,8 +43,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
* \brief Apply projection.
*
* \param base_tuple Tuple to apply projection.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Multi index after projection.
*/
template
<
typename
MultiIndex
,
typename
ProjectionTuple
>
...
...
@@ -73,7 +73,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
}
else
{
return
base_tuple
.
At
(
i_num
);
return
make_tuple
(
base_tuple
.
At
(
i_num
)
)
;
}
},
Number
<
MultiIndex
::
Size
()
>
{});
...
...
@@ -86,8 +86,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
* \brief Calculate shape with dims from projection.
*
* \param shape Base tensor shape.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Shape with dims from projection
*/
template
<
typename
...
Ts
,
typename
...
Ps
>
...
...
@@ -119,22 +120,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts..
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tuple with blocks number.
*/
template
<
typename
...
Ts
,
typename
...
Ls
,
typename
...
Ps
>
__host__
__device__
constexpr
auto
CalculateGridSize
(
const
Tuple
<
Ts
...
>&
shape
,
const
Tuple
<
Ls
...
>&
tile_shape
,
const
Tuple
<
Ps
...
>&
projection
)
const
Tuple
<
Ls
...
>&
tile_shape
)
{
auto
shape_with_projection
=
CalculateShapeWithProjection
(
shape
,
projection
);
return
generate_tuple
(
[
&
](
auto
i
)
{
return
ck
::
math
::
integer_divide_ceil
(
size
<
i
>
(
shape_with_projection
),
size
<
i
>
(
tile_shape
));
},
[
&
](
auto
i
)
{
return
ck
::
math
::
integer_divide_ceil
(
size
<
i
>
(
shape
),
size
<
i
>
(
tile_shape
));
},
Number
<
Tuple
<
Ls
...
>::
Size
()
>
{});
}
...
...
@@ -155,6 +148,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
return
thread_idxs
*
partition_lengths_seq
+
old_offset_idxs
;
}
/**
* \brief Select dims to partition (skip if slice).
*
* \param block_idxs Input block indexes.
* \return Partitioned dims.
*/
template
<
typename
BlockIdxs
>
__host__
__device__
constexpr
auto
GetDimsToPartition
([[
maybe_unused
]]
const
BlockIdxs
&
block_idxs
)
{
const
auto
dims_to_partition
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
!
is_detected
<
is_slice
,
tuple_element_t
<
i
,
BlockIdxs
>>::
value
)
{
return
Number
<
i
>
{};
}
else
{
return
Tuple
<>
{};
}
},
Number
<
BlockIdxs
::
Size
()
>
{});
// Remove empty tuples
return
UnrollNestedTuple
<
0
,
1
>
(
dims_to_partition
);
}
/**
* \brief Replace slices with zeros (Slice dims are not partitioned).
*
* \param block_idxs Input block indexes.
* \return Parsed dims.
*/
template
<
typename
BlockIdxs
>
__host__
__device__
constexpr
auto
ReplaceSlicesWithZeros
(
const
BlockIdxs
&
block_idxs
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
!
is_detected
<
is_slice
,
tuple_element_t
<
i
,
BlockIdxs
>>::
value
)
{
return
block_idxs
.
At
(
i
);
}
else
{
return
Number
<
0
>
{};
}
},
Number
<
BlockIdxs
::
Size
()
>
{});
}
/**
* \brief Calculate default projection.
*
...
...
@@ -168,6 +209,31 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
return
generate_tuple
([
&
](
auto
)
{
return
Number
<
1
>
{};
},
Number
<
TileShape
::
Size
()
>
{});
}
/**
* \brief Calculate thread multi index from 1d thread index.
*
* \param thread_layout Layout of threads (could not be nested).
* \param thread_id Thread index represented as integer.
* \return Multi index.
*/
template
<
typename
ThreadShape
,
typename
ThreadUnrolledDesc
>
__host__
__device__
constexpr
auto
CalculateThreadMultiIdx
(
[[
maybe_unused
]]
const
Layout
<
ThreadShape
,
ThreadUnrolledDesc
>&
thread_layout
,
const
index_t
thread_id
)
{
static_assert
(
ThreadUnrolledDesc
::
GetNumOfTransform
()
==
1
,
"Thread layout should not be transformed."
);
constexpr
auto
embed_transform
=
ThreadUnrolledDesc
{}.
GetTransforms
().
At
(
Number
<
0
>
{});
constexpr
auto
shape
=
ThreadShape
{};
constexpr
auto
strides
=
embed_transform
.
coefficients_
;
return
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
auto
num_i
=
Number
<
i
>
{};
return
(
thread_id
/
strides
.
At
(
num_i
))
%
shape
.
At
(
num_i
);
},
Number
<
ThreadShape
::
Size
()
>
{});
}
}
// namespace detail
}
// namespace
...
...
@@ -176,51 +242,62 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
* is supported).
*
* \param tensor Tensor for partition.
* \param thread_l
engths
Layout of threads (could not be
nest
ed).
* \param thread_l
ayout
Layout of threads (could not be
transform
ed).
* \param thread_id Thread index represented as integer.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Partition tensor.
*/
template
<
typename
TensorType
,
typename
ThreadLengthsTuple
,
typename
ProjectionTuple
>
template
<
typename
TensorType
,
typename
ThreadShape
,
typename
ThreadUnrolledDesc
,
typename
ProjectionTuple
>
__host__
__device__
constexpr
auto
make_local_partition
(
TensorType
&
tensor
,
[[
maybe_unused
]]
const
ThreadLengthsTuple
&
thread_l
engths
,
[[
maybe_unused
]]
const
Layout
<
ThreadShape
,
ThreadUnrolledDesc
>
&
thread_l
ayout
,
const
index_t
thread_id
,
const
ProjectionTuple
&
projection
)
{
static_assert
(
!
IsNestedTuple
(
Thread
LengthsTupl
e
{}));
static_assert
(
!
IsNestedTuple
(
Thread
Shap
e
{}));
// Calculate new partition shape
const
auto
&
tensor_shape
=
shape
(
tensor
);
// Calculate projected thread lengths
constexpr
auto
projected_thread_lengths
=
detail
::
ApplyProjection
(
Thread
LengthsTupl
e
{},
ProjectionTuple
{});
detail
::
ApplyProjection
(
Thread
Shap
e
{},
ProjectionTuple
{});
constexpr
auto
partition_shape
=
detail
::
CalculateLocalPartitionShape
(
decltype
(
tensor_shape
){},
projected_thread_lengths
);
// Create Thread Cluster Descriptor
constexpr
auto
partition_shape_seq
=
generate_sequence_v2
([
&
](
auto
I
)
{
return
size
<
I
>
(
partition_shape
);
},
Number
<
decltype
(
partition_shape
)
::
Size
()
>
{});
constexpr
auto
thread_lengths_seq
=
generate_sequence_v2
([
&
](
auto
I
)
{
return
size
<
I
>
(
ThreadLengthsTuple
{});
},
Number
<
ThreadLengthsTuple
::
Size
()
>
{});
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
thread_lengths_seq
);
// Calculate thread idxs and offsets
const
auto
thread_idxs
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
)
)
;
const
auto
thread_idxs
=
detail
::
CalculateThreadMultiIdx
(
thread_layout
,
thread_id
);
// Apply projection on thread idxs to remove not needed idxs
const
auto
projected_thread_idxs
=
detail
::
ApplyProjection
(
thread_idxs
,
projection
);
const
auto
offset_multi_idxs
=
detail
::
CalculateOffsetMultiIdxs
(
projected_thread_idxs
,
partition_shape_seq
,
tensor
.
GetMultiIdxOffsets
());
// Create new layout and tensor
auto
&
unrolled_desc
=
layout
(
tensor
).
GetUnrolledDescriptor
();
// Slice descriptor
const
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_slice_transform
(
partition_shape
.
At
(
i
),
offset_multi_idxs
.
At
(
i
),
partition_shape
.
At
(
i
)
+
offset_multi_idxs
.
At
(
i
));
},
Number
<
remove_reference_t
<
decltype
(
tensor_shape
)
>::
Size
()
>
{});
const
auto
lower_upper_dims
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
remove_reference_t
<
decltype
(
tensor_shape
)
>::
Size
()
>
{});
auto
sliced_desc
=
transform_tensor_descriptor
(
unrolled_desc
,
transforms
,
lower_upper_dims
,
lower_upper_dims
);
// Create layout
const
auto
partition_layout
=
Layout
<
remove_reference_t
<
decltype
(
partition_shape
)
>
,
decltype
(
unroll
ed_desc
)
>
(
partition_shape
,
unroll
ed_desc
);
Layout
<
remove_reference_t
<
decltype
(
partition_shape
)
>
,
decltype
(
slic
ed_desc
)
>
(
partition_shape
,
slic
ed_desc
);
auto
partition_tensor
=
make_tensor
<
TensorType
::
TensorBufferAddressSpace
>
(
tensor
.
GetPointer
(),
partition_layout
);
// Apply offsets
partition_tensor
.
SetMultiIdxOffset
(
to_multi_index
(
offset_multi_idxs
));
return
partition_tensor
;
}
...
...
@@ -233,12 +310,13 @@ make_local_partition(TensorType& tensor,
* \param thread_id Thread index represented as integer.
* \return Partition tensor.
*/
template
<
typename
TensorType
,
typename
ThreadLengthsTuple
>
__host__
__device__
constexpr
auto
make_local_partition
(
TensorType
&
tensor
,
const
ThreadLengthsTuple
&
thread_lengths
,
const
index_t
thread_id
)
template
<
typename
TensorType
,
typename
ThreadShape
,
typename
ThreadUnrolledDesc
>
__host__
__device__
constexpr
auto
make_local_partition
(
TensorType
&
tensor
,
const
Layout
<
ThreadShape
,
ThreadUnrolledDesc
>&
thread_lengths
,
const
index_t
thread_id
)
{
const
auto
projection
=
detail
::
GenerateDefaultProjection
(
Thread
LengthsTupl
e
{});
const
auto
projection
=
detail
::
GenerateDefaultProjection
(
Thread
Shap
e
{});
return
make_local_partition
(
tensor
,
thread_lengths
,
thread_id
,
projection
);
}
...
...
@@ -252,21 +330,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param block_idxs Tuple of block indexes represented as integer. If slice,
* then get whole dim.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tile tensor.
*/
template
<
typename
TensorType
,
typename
BlockShapeTuple
,
typename
ProjectionTuple
>
template
<
typename
TensorType
,
typename
BlockShapeTuple
,
typename
BlockIdxs
,
typename
ProjectionTuple
>
__host__
__device__
constexpr
auto
make_local_tile
(
const
TensorType
&
tensor
,
const
BlockShapeTuple
&
tile_shape
,
const
index_t
block_id
,
const
BlockIdxs
&
block_id
xs
,
const
ProjectionTuple
&
projection
)
{
static_assert
(
!
IsNestedTuple
(
BlockShapeTuple
{}));
constexpr
bool
is_default_projection
=
is_same_v
<
ProjectionTuple
,
decltype
(
detail
::
GenerateDefaultProjection
(
BlockShapeTuple
{}))
>
;
static_assert
(
!
IsNestedTuple
(
BlockIdxs
{}));
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -274,49 +355,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
auto
&
aligned_desc
=
layout
(
tensor
).
GetMergedNestingDescriptor
();
// TODO: Enable block_2_tile_map partitioning for non-default projection.
if
constexpr
(
BlockShapeTuple
::
Size
()
==
I2
&&
is_default_projection
)
constexpr
auto
projected_tile_shape
=
detail
::
ApplyProjection
(
BlockShapeTuple
{},
ProjectionTuple
{});
// Number of dims which are partitioned
constexpr
auto
dims_to_partition
=
detail
::
GetDimsToPartition
(
BlockIdxs
{});
const
auto
parsed_block_idxs
=
detail
::
ReplaceSlicesWithZeros
(
block_idxs
);
if
constexpr
(
decltype
(
dims_to_partition
)
::
Size
()
==
I2
)
{
// Optimized version for 2d tile shape [MxK]
const
auto
shape_with_projection_dims
=
detail
::
CalculateShapeWithProjection
(
shape
(
tensor
),
projection
);
// Set Value for M, N partition
const
auto
M
=
shape_with_projection_dims
.
At
(
dims_to_partition
.
At
(
I0
));
const
auto
N
=
shape_with_projection_dims
.
At
(
dims_to_partition
.
At
(
I1
));
constexpr
auto
MPerBlock
=
BlockShapeTuple
{}.
At
(
dims_to_partition
.
At
(
I0
));
constexpr
auto
NPerBlock
=
BlockShapeTuple
{}.
At
(
dims_to_partition
.
At
(
I1
));
auto
m_n_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
));
// Get 1D block id
const
auto
grid_size
=
detail
::
CalculateGridSize
(
shape_with_projection_dims
,
tile_shape
);
const
auto
block_lengths_desc
=
make_naive_tensor_descriptor_packed
(
grid_size
);
const
index_t
block_id_1d
=
block_lengths_desc
.
CalculateOffset
(
parsed_block_idxs
);
// Optimized version for 2d tile shape [MxN]
const
auto
block_2_tile_map
=
BlockToCTileMap_M00_N0_M01Adapt
<
Block
ShapeTuple
{}.
At
(
I0
)
,
Block
ShapeTuple
{}.
At
(
I1
)
,
remove_cvref_t
<
decltype
(
aligned
_desc
)
>>
(
aligned
_desc
);
BlockToCTileMap_M00_N0_M01Adapt
<
MPer
Block
,
NPer
Block
,
remove_cvref_t
<
decltype
(
m_n
_desc
)
>>
(
m_n
_desc
);
const
auto
block_work_idx
=
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
block_id
));
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
block_id
_1d
));
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
size
<
0
>
(
tile_shape
));
const
index_t
k_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
size
<
1
>
(
tile_shape
));
const
auto
offset_multi_idxs
=
make_tuple
(
m_block_data_idx_on_grid
,
k_block_data_idx_on_grid
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
// Apply 0 for non partitioned dims
const
auto
offset_multi_idxs
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
dims_to_partition
.
At
(
I0
))
{
return
m_block_data_idx_on_grid
;
}
else
if
constexpr
(
i
==
dims_to_partition
.
At
(
I1
))
{
return
n_block_data_idx_on_grid
;
}
else
{
return
Number
<
0
>
{};
}
},
Number
<
BlockShapeTuple
::
Size
()
>
{});
const
auto
projected_offset_multi_idxs
=
detail
::
ApplyProjection
(
offset_multi_idxs
,
projection
);
// Create new layout and tensor
const
auto
tile_layout
=
Layout
<
remove_reference_t
<
decltype
(
tile_shape
)
>
,
decltype
(
aligned_desc
)
>
(
tile_shape
,
aligned_desc
);
Layout
<
remove_reference_t
<
decltype
(
projected_
tile_shape
)
>
,
decltype
(
aligned_desc
)
>
(
projected_tile_shape
,
aligned_desc
);
auto
tile_tensor
=
make_tensor
<
TensorType
::
TensorBufferAddressSpace
>
(
tensor
.
GetPointer
(),
tile_layout
);
// Apply offsets
tile_tensor
.
SetMultiIdxOffset
(
to_multi_index
(
offset_multi_idxs
));
tile_tensor
.
SetMultiIdxOffset
(
to_multi_index
(
projected_
offset_multi_idxs
));
return
tile_tensor
;
}
else
{
// Calculate offsets
// Sequence with data to process per block
constexpr
auto
projected_tile_shape
=
detail
::
ApplyProjection
(
BlockShapeTuple
{},
ProjectionTuple
{});
using
ProjectedTileShapeTuple
=
decltype
(
projected_tile_shape
);
constexpr
auto
projected_tile_shape_seq
=
generate_sequence_v2
([](
auto
I
)
{
return
ProjectedTileShapeTuple
{}.
At
(
I
);
},
Number
<
ProjectedTileShapeTuple
::
Size
()
>
{});
// Tuple with number of blocks
const
auto
block_lengths
=
detail
::
CalculateGridSize
(
shape
(
tensor
),
tile_shape
,
projection
);
const
auto
block_cluster_desc_
=
make_cluster_descriptor
(
block_lengths
);
const
auto
block_idxs
=
block_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
block_id
));
const
auto
projected_block_idxs
=
detail
::
ApplyProjection
(
block_idxs
,
projection
);
const
auto
offset_multi_idxs
=
detail
::
CalculateOffsetMultiIdxs
(
const
auto
projected_block_idxs
=
to_multi_index
(
detail
::
ApplyProjection
(
parsed_block_idxs
,
projection
));
const
auto
offset_multi_idxs
=
detail
::
CalculateOffsetMultiIdxs
(
projected_block_idxs
,
projected_tile_shape_seq
,
tensor
.
GetMultiIdxOffsets
());
// Create new layout and tensor
const
auto
tile_layout
=
...
...
@@ -338,52 +447,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param block_idxs Tuple of block indexes represented as integer. If slice,
* then get whole dim.
* \return Tile tensor.
*/
template
<
typename
TensorType
,
typename
BlockShapeTuple
>
__host__
__device__
constexpr
auto
make_local_tile
(
const
TensorType
&
tensor
,
const
BlockShapeTuple
&
tile_shape
,
const
index_t
block_id
)
template
<
typename
TensorType
,
typename
BlockShapeTuple
,
typename
BlockIdxs
>
__host__
__device__
constexpr
auto
make_local_tile
(
const
TensorType
&
tensor
,
const
BlockShapeTuple
&
tile_shape
,
const
BlockIdxs
&
block_idxs
)
{
const
auto
projection
=
detail
::
GenerateDefaultProjection
(
BlockShapeTuple
{});
return
make_local_tile
(
tensor
,
tile_shape
,
block_id
,
projection
);
}
/**
* \brief Pad tensor shapes to be adjusted to tile lengths.
*
*
* \param tensor Tensor to pad.
* \param tile_lengths Tile lengths to align tensor shape.
* \return Padded tensor.
*/
template
<
typename
TensorType
,
typename
TileLengths
>
__host__
__device__
constexpr
auto
pad
(
const
TensorType
&
tensor
,
const
TileLengths
&
tile_lengths
)
{
const
auto
&
tensor_shape
=
shape
(
tensor
);
using
TensorShapeType
=
remove_reference_t
<
decltype
(
tensor_shape
)
>
;
auto
&
unrolled_desc
=
layout
(
tensor
).
GetUnrolledDescriptor
();
// Generate sequence with ones to mark that all dims will be padded
constexpr
auto
do_pads_seq
=
generate_sequence_v2
([](
auto
)
{
return
Number
<
1
>
{};
},
Number
<
TensorShapeType
::
Size
()
>
{});
// Create descriptor with padding
auto
padded_desc
=
tensor_operation
::
device
::
PadTensorDescriptor
(
unrolled_desc
,
tile_lengths
,
do_pads_seq
);
// Generate padded shape
const
auto
padded_shape
=
generate_tuple
(
[
&
](
auto
i
)
{
const
auto
&
dim
=
size
<
i
>
(
tensor_shape
);
const
auto
&
tile_length
=
size
<
i
>
(
tile_lengths
);
return
ck
::
math
::
integer_divide_ceil
(
dim
,
tile_length
)
*
tile_length
;
},
Number
<
TileLengths
::
Size
()
>
{});
// Create layout and tensor
const
auto
padded_layout
=
Layout
<
decltype
(
padded_shape
),
decltype
(
padded_desc
)
>
(
padded_shape
,
padded_desc
);
auto
partition_tensor
=
make_tensor
<
TensorType
::
TensorBufferAddressSpace
>
(
tensor
.
GetPointer
(),
padded_layout
);
partition_tensor
.
SetMultiIdxOffset
(
tensor
.
GetMultiIdxOffsets
());
return
partition_tensor
;
return
make_local_tile
(
tensor
,
tile_shape
,
block_idxs
,
projection
);
}
}
// namespace wrapper
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -25,25 +25,35 @@ template <ck::index_t NDimSpatial,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ck
::
index_t
NumAElementwiseTensor
=
0
,
ck
::
index_t
NumBElementwiseTensor
=
0
,
ck
::
index_t
NumDElementwiseTensor
=
0
,
typename
std
::
enable_if
<
NDimSpatial
>
=
1
&&
NDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
ReferenceConvBwdData
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
const
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
Argument
(
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
const
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
std
::
array
<
Tensor
<
InDataType
>
,
NumAElementwiseTensor
>&
elementwise_a_tensors
,
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors
,
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors
)
:
input_
{
input
},
weight_
{
weight
},
output_
{
output
},
elementwise_a_tensors_
{
elementwise_a_tensors
},
elementwise_b_tensors_
{
elementwise_b_tensors
},
elementwise_d_tensors_
{
elementwise_d_tensors
},
conv_strides_
{
conv_filter_strides
},
conv_dilations_
{
conv_filter_dilations
},
in_left_pads_
{
input_left_pads
},
...
...
@@ -58,6 +68,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const
Tensor
<
WeiDataType
>&
weight_
;
const
Tensor
<
OutDataType
>&
output_
;
const
std
::
array
<
Tensor
<
InDataType
>
,
NumAElementwiseTensor
>&
elementwise_a_tensors_
;
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors_
;
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
...
...
@@ -106,26 +120,46 @@ struct ReferenceConvBwdData : public device::BaseOperator
{
for
(
std
::
size_t
k
=
0
;
k
<
K
;
++
k
)
{
float
v_out
=
0
;
float
v_wei
=
0
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
g
,
n
,
k
,
wo
)));
arg
.
wei_element_op_
(
v_wei
,
ck
::
type_convert
<
float
>
(
arg
.
weight_
(
g
,
k
,
c
,
x
)));
v_acc
+=
v_out
*
v_wei
;
OutDataType
v_out
;
WeiDataType
v_wei
;
ExecuteElementwiseOp
(
arg
.
out_element_op_
,
arg
.
elementwise_a_tensors_
,
Number
<
NumAElementwiseTensor
>
{},
v_out
,
arg
.
output_
(
g
,
n
,
k
,
wo
),
g
,
n
,
k
,
wo
);
ExecuteElementwiseOp
(
arg
.
wei_element_op_
,
arg
.
elementwise_b_tensors_
,
Number
<
NumBElementwiseTensor
>
{},
v_wei
,
arg
.
weight_
(
g
,
k
,
c
,
x
),
g
,
k
,
c
,
x
);
v_acc
+=
ck
::
type_convert
<
float
>
(
v_out
)
*
ck
::
type_convert
<
float
>
(
v_wei
);
}
}
}
}
float
v_in
;
arg
.
in_element_op_
(
v_in
,
v_acc
);
arg
.
input_
(
g
,
n
,
c
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_in
);
InDataType
v_acc_converted
=
ck
::
type_convert
<
InDataType
>
(
v_acc
);
InDataType
&
v_in
=
arg
.
input_
(
g
,
n
,
c
,
wi
);
ExecuteElementwiseOp
(
arg
.
in_element_op_
,
arg
.
elementwise_d_tensors_
,
Number
<
NumDElementwiseTensor
>
{},
v_in
,
v_acc_converted
,
g
,
n
,
c
,
wi
);
};
make_ParallelTensorFunctor
(
f_ncw
,
...
...
@@ -175,20 +209,34 @@ struct ReferenceConvBwdData : public device::BaseOperator
{
for
(
std
::
size_t
k
=
0
;
k
<
K
;
++
k
)
{
float
v_out
=
0
;
float
v_wei
=
0
;
OutDataType
v_out
;
WeiDataType
v_wei
;
arg
.
out_element_op_
(
ExecuteElementwiseOp
(
arg
.
out_element_op_
,
arg
.
elementwise_a_tensors_
,
Number
<
NumAElementwiseTensor
>
{},
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
g
,
n
,
k
,
ho
,
wo
)));
arg
.
wei_element_op_
(
arg
.
output_
(
g
,
n
,
k
,
ho
,
wo
),
g
,
n
,
k
,
ho
,
wo
);
ExecuteElementwiseOp
(
arg
.
wei_element_op_
,
arg
.
elementwise_b_tensors_
,
Number
<
NumBElementwiseTensor
>
{},
v_wei
,
ck
::
type_convert
<
float
>
(
arg
.
weight_
(
g
,
k
,
c
,
y
,
x
)));
v_acc
+=
v_out
*
v_wei
;
arg
.
weight_
(
g
,
k
,
c
,
y
,
x
),
g
,
k
,
c
,
y
,
x
);
v_acc
+=
ck
::
type_convert
<
float
>
(
v_out
)
*
ck
::
type_convert
<
float
>
(
v_wei
);
}
}
}
...
...
@@ -197,11 +245,18 @@ struct ReferenceConvBwdData : public device::BaseOperator
            }
        }
        float v_in;
        arg.in_element_op_(v_in, v_acc);
        arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
        InDataType v_acc_converted = ck::type_convert<InDataType>(v_acc);
        InDataType& v_in           = arg.input_(g, n, c, hi, wi);
        ExecuteElementwiseOp(arg.in_element_op_,
                             arg.elementwise_d_tensors_,
                             Number<NumDElementwiseTensor>{},
                             v_in,
                             v_acc_converted,
                             g, n, c, hi, wi);
    };

    make_ParallelTensorFunctor(f_nchw,
...
...
@@ -270,20 +325,37 @@ struct ReferenceConvBwdData : public device::BaseOperator
                {
                    for(std::size_t k = 0; k < K; ++k)
                    {
                        float v_out = 0;
                        float v_wei = 0;
                        OutDataType v_out;
                        WeiDataType v_wei;
                        arg.out_element_op_(
                        ExecuteElementwiseOp(arg.out_element_op_,
                                             arg.elementwise_a_tensors_,
                                             Number<NumAElementwiseTensor>{},
                                             v_out,
                            ck::type_convert<float>(arg.output_(g, n, k, do_, ho, wo)));
                        arg.wei_element_op_(
                                             arg.output_(g, n, k, do_, ho, wo),
                                             g, n, k, do_, ho, wo);
                        ExecuteElementwiseOp(arg.wei_element_op_,
                                             arg.elementwise_b_tensors_,
                                             Number<NumBElementwiseTensor>{},
                                             v_wei,
                            ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
                        v_acc += v_out * v_wei;
                                             arg.weight_(g, k, c, z, y, x),
                                             g, k, c, z, y, x);
                        v_acc += ck::type_convert<float>(v_out) * ck::type_convert<float>(v_wei);
                    }
                }
            }
...
...
@@ -295,11 +367,19 @@ struct ReferenceConvBwdData : public device::BaseOperator
            }
        }
        float v_in;
        arg.in_element_op_(v_in, v_acc);
        arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
        InDataType v_acc_converted = ck::type_convert<InDataType>(v_acc);
        InDataType& v_in           = arg.input_(g, n, c, di, hi, wi);
        ExecuteElementwiseOp(arg.in_element_op_,
                             arg.elementwise_d_tensors_,
                             Number<NumDElementwiseTensor>{},
                             v_in,
                             v_acc_converted,
                             g, n, c, di, hi, wi);
    };

    make_ParallelTensorFunctor(f_ncdhw,
...
...
@@ -325,6 +405,36 @@ struct ReferenceConvBwdData : public device::BaseOperator
        }
    };

    template <typename... Args,
              typename ElementwiseOp,
              typename ElementwiseTensor,
              typename NumTensor,
              typename T>
    static void ExecuteElementwiseOp(ElementwiseOp& elementwise_op,
                                     ElementwiseTensor& elementwise_tensors,
                                     NumTensor,
                                     T& y,
                                     const T& x,
                                     Args... dims)
    {
        if constexpr(NumTensor::value == 0)
        {
            elementwise_op(y, x);
        }
        else if constexpr(NumTensor::value == 1)
        {
            elementwise_op(y, x, elementwise_tensors[0](dims...));
        }
        else if constexpr(NumTensor::value == 2)
        {
            elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
        }
        else
        {
            throw std::runtime_error("ElementOp not supported in reference.");
        }
    }

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
...
...
@@ -333,16 +443,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(Tensor<InDataType>& input,
                             const Tensor<WeiDataType>& weight,
                             const Tensor<OutDataType>& output,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    static auto MakeArgument(Tensor<InDataType>& input,
                             const Tensor<WeiDataType>& weight,
                             const Tensor<OutDataType>& output,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op,
                             const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors = {},
                             const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors = {},
                             const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors = {})
    {
        return Argument{input,
                        weight,
...
...
@@ -353,7 +467,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
                        out_element_op,
                        elementwise_a_tensors,
                        elementwise_b_tensors,
                        elementwise_d_tensors};
    }

    static auto MakeInvoker() { return Invoker{}; }
...
...
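Aside (not part of the commit): the ExecuteElementwiseOp helper added above is just a compile-time dispatch on how many auxiliary elementwise tensors accompany the operation. The following is an illustrative, self-contained C++ sketch of that pattern with hypothetical names (Arity, apply_elementwise, plain std::vector instead of CK Tensors); it is not the CK implementation.

#include <array>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

// Hypothetical compile-time arity tag, standing in for ck::Number<N>.
template <std::size_t N>
struct Arity
{
    static constexpr std::size_t value = N;
};

// Apply `op` with 0, 1 or 2 auxiliary operands, mirroring the if-constexpr chain above.
template <std::size_t N, typename Op, typename T>
void apply_elementwise(Op op, const std::array<std::vector<T>, N>& aux, Arity<N>, T& y, const T& x, std::size_t i)
{
    if constexpr(N == 0)
        op(y, x);
    else if constexpr(N == 1)
        op(y, x, aux[0][i]);
    else if constexpr(N == 2)
        op(y, x, aux[0][i], aux[1][i]);
    else
        throw std::runtime_error("unsupported auxiliary-tensor count");
}

int main()
{
    // One auxiliary tensor: a bilinear-style combine y = x + 0.5 * d.
    std::array<std::vector<float>, 1> aux = {std::vector<float>{10.f, 20.f, 30.f}};
    auto combine = [](float& y, const float& x, const float& d) { y = x + 0.5f * d; };

    float y = 0.f;
    apply_elementwise(combine, aux, Arity<1>{}, y, 2.f, /*i=*/1);
    std::cout << y << '\n'; // prints 12 (2 + 0.5 * 20)
}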
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
View file @ 6d9a07d7
...
...
@@ -189,6 +189,11 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(
        DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(
    std::vector<std::unique_ptr<
        DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
...
...
@@ -352,6 +357,7 @@ struct DeviceOperationInstanceFactory<
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(op_ptrs);
            add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                          is_same_v<CLayout, Row>)
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp
0 → 100644
View file @ 6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;
using BF8  = ck::bf8_t;
using F8   = ck::f8_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear    = ck::tensor_operation::element_wise::Bilinear;

static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;

static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
    ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;

// f16_f16_f32_f16
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_f16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
        // clang-format on
        >;
// bf16_bf16_f32_bf16
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_bf16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
        // clang-format on
        >;
// f32_f32_f32_f32
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>
        // clang-format on
        >;
// f16_f16_f16_comp_f8
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_input_fp16_comp_bf8f8_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, BF8, F8>,
// instances for small conv.K and conv.C
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, LoopScheduler::Default, BF8, F8>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>
        // clang-format on
        >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp
0 → 100644
View file @ 6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear    = ck::tensor_operation::element_wise::Bilinear;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;

static constexpr auto ConvFwdOddC =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;

static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_bf16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
        // clang-format on
        >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
        // clang-format on
        >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
// instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
        // clang-format on
        >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_int8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
        // clang-format on
        >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp
0 → 100644
View file @ 6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Tuple<NDHWGC>, NDHWGC, F16, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Tuple<NDHWGC>, NDHWGC, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Tuple<NDHWGC>, NDHWGC, BF16, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
          typename OutLayout,
          typename WeiLayout,
          typename InLayout,
          typename OutDataType,
          typename WeiDataType,
          typename InDataType,
          typename ComputeTypeA,
          typename ComputeTypeB>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
        NumDimSpatial, OutLayout, WeiLayout, Tuple<InLayout>, InLayout, OutDataType, WeiDataType,
        Tuple<InDataType>, InDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::Bilinear,
        ComputeTypeA, ComputeTypeB>>
{
    using DeviceOp = DeviceGroupedConvBwdDataMultipleD<
        NumDimSpatial, OutLayout, WeiLayout, Tuple<InLayout>, InLayout, OutDataType, WeiDataType,
        Tuple<InDataType>, InDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::Bilinear,
        ComputeTypeA, ComputeTypeB>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 3)
        {
            if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
                         is_same_v<OutLayout, NDHWGK>)
            {
#ifdef CK_ENABLE_FP16
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
                             is_same_v<ComputeTypeB, F16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances(op_ptrs);
                }
#endif
#ifdef CK_ENABLE_FP32
                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                                  is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
                                  is_same_v<ComputeTypeB, F32>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(op_ptrs);
                }
#endif
#ifdef CK_ENABLE_BF16
                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                                  is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
                                  is_same_v<ComputeTypeB, BF16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(op_ptrs);
                }
#endif
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
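Aside (not part of the commit): client code consumes this specialization like the other CK instance factories. A hedged usage sketch, assuming the usual CK client pattern and the BaseOperator::GetTypeString() accessor; the template arguments are the ones declared in the header above.

#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp"

int main()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_layout::convolution;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Bilinear    = ck::tensor_operation::element_wise::Bilinear;
    using F16         = ck::half_t;

    // fp16 bilinear bwd-data op for NDHWGK/GKZYXC/NDHWGC, as registered above.
    using DeviceOp = DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, ck::Tuple<NDHWGC>, NDHWGC,
                                                       F16, F16, ck::Tuple<F16>, F16,
                                                       PassThrough, PassThrough, Bilinear, F16, F16>;

    const auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    std::cout << "found " << op_ptrs.size() << " instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n'; // assumes the usual BaseOperator interface
}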
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp
0 → 100644
View file @ 6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear    = ck::tensor_operation::element_wise::Bilinear;

#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK>, NDHWGK, BF16, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK>, NDHWGK, F16, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK>, NDHWGK, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK>, NDHWGK, int8_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DLayouts,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DDataTypes,
          typename OutDataType,
          typename ComputeType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
        NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout, InDataType, WeiDataType,
        DDataTypes, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::Bilinear,
        ComputeType>>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD<
        NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout, InDataType, WeiDataType,
        DDataTypes, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::Bilinear,
        ComputeType>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
                     is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
                     DLayouts::Size() == 1 && is_same_v<tuple_element_t<0, DLayouts>, NDHWGK>)
        {
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                         is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_BF16
            if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
                         is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_INT8
            if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                         is_same_v<OutDataType, int8_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
            }
#endif
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
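Aside (not part of the commit): every instance in the two new bilinear headers carries exactly one D tensor in its Ds tuple, because the CDE elementwise operation combines the convolution result with one extra operand. As a rough mental model only (a stand-in sketch, not CK's element_wise::Bilinear source), the combine is of the form e = alpha * c + beta * d:

// Hypothetical illustration of a bilinear CDE combine; names and exact conversions are assumptions.
struct BilinearSketch
{
    float alpha_ = 1.f;
    float beta_  = 1.f;

    template <typename E, typename C, typename D>
    void operator()(E& e, const C& c, const D& d) const
    {
        // scale the GEMM/conv result and the residual D operand, then narrow to the output type
        e = static_cast<E>(alpha_ * static_cast<float>(c) + beta_ * static_cast<float>(d));
    }
};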
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
View file @ 6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -97,6 +97,35 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(
        PassThrough,
        PassThrough>>>& instances);

// bf16_inputA i8_inputB
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row, Row, Empty_Tuple, Row, BF16, I8, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row, Col, Empty_Tuple, Row, BF16, I8, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

template <typename ALayout,
          typename BLayout,
          typename ELayout,
...
...
@@ -180,6 +209,24 @@ struct DeviceOperationInstanceFactory<
            }
        }

        // bf16_i8_input
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
                     is_same_v<EDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<ELayout, Row>)
            {
                add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs);
            }
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                         is_same_v<ELayout, Row>)
            {
                add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs);
            }
        }
#endif

        return op_ptrs;
    }
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp
View file @ 6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -17,7 +17,32 @@ namespace tensor_operation {
namespace device {
namespace instance {

void add_device_permute_scale_f16_instances(
#ifdef CK_ENABLE_FP16
void add_device_permute_scale_1d_f16_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, element_wise::UnarySquare, Scale, 1>>>&);

void add_device_permute_scale_2d_f16_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, element_wise::UnarySquare, Scale, 2>>>&);

void add_device_permute_scale_3d_f16_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, element_wise::UnarySquare, Scale, 3>>>&);

void add_device_permute_scale_4d_f16_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough,
...
...
@@ -25,7 +50,50 @@ void add_device_permute_scale_f16_instances(
                                                  Scale, 4>>>&);

void add_device_permute_scale_f32_instances(
void add_device_permute_scale_5d_f16_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, element_wise::UnarySquare, Scale, 5>>>&);

void add_device_permute_scale_6d_f16_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, element_wise::UnarySquare, Scale, 6>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_permute_scale_1d_f32_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, element_wise::UnarySquare, Scale, 1>>>&);

void add_device_permute_scale_2d_f32_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, element_wise::UnarySquare, Scale, 2>>>&);

void add_device_permute_scale_3d_f32_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, element_wise::UnarySquare, Scale, 3>>>&);

void add_device_permute_scale_4d_f32_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough,
...
...
@@ -33,6 +101,23 @@ void add_device_permute_scale_f32_instances(
                                                  Scale, 4>>>&);

void add_device_permute_scale_5d_f32_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, element_wise::UnarySquare, Scale, 5>>>&);

void add_device_permute_scale_6d_f32_instances(
    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, element_wise::UnarySquare, Scale, 6>>>&);
#endif

template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
...
...
@@ -57,15 +142,107 @@ struct DeviceOperationInstanceFactory<
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
                     is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
        if constexpr(NumDim == 1)
        {
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
            {
                add_device_permute_scale_1d_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
            {
                add_device_permute_scale_1d_f16_instances(op_ptrs);
            }
#endif
        }
        else if constexpr(NumDim == 2)
        {
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
            {
                add_device_permute_scale_2d_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
            {
                add_device_permute_scale_2d_f16_instances(op_ptrs);
            }
#endif
        }
        else if constexpr(NumDim == 3)
        {
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
            {
                add_device_permute_scale_3d_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
            {
                add_device_permute_scale_3d_f16_instances(op_ptrs);
            }
#endif
        }
        else if constexpr(NumDim == 4)
        {
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
            {
                add_device_permute_scale_4d_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
            {
                add_device_permute_scale_4d_f16_instances(op_ptrs);
            }
#endif
        }
        else if constexpr(NumDim == 5)
        {
            add_device_permute_scale_f32_instances(op_ptrs);
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
            {
                add_device_permute_scale_5d_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
            {
                add_device_permute_scale_5d_f16_instances(op_ptrs);
            }
#endif
        }
        else if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
                          is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
        else if constexpr(NumDim == 6)
        {
            add_device_permute_scale_f16_instances(op_ptrs);
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
            {
                add_device_permute_scale_6d_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
                         is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
            {
                add_device_permute_scale_6d_f16_instances(op_ptrs);
            }
#endif
        }

        return op_ptrs;
    }
...
...
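Aside (not part of the commit): with the factory now keyed on NumDim as well as the data-type tuples, a caller selects, for example, the 3-D f32 permute+scale instances as sketched below. This is illustrative only and assumes the DeviceElementwise interface declared above; the helper name is hypothetical.

#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp"

// Hypothetical helper: list the 3-D f32 permute+scale instances the factory returns.
inline auto get_3d_f32_permute_scale_instances()
{
    using namespace ck::tensor_operation;
    using DeviceOp = device::DeviceElementwise<ck::Tuple<float>,
                                               ck::Tuple<float>,
                                               element_wise::PassThrough,
                                               element_wise::UnarySquare,
                                               element_wise::Scale,
                                               3>;
    // NumDim == 3 routes to add_device_permute_scale_3d_f32_instances(...) above.
    return device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
}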
library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.cpp → library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp
View file @ 6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Pass    = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale   = ck::tensor_operation::element_wise::Scale;

// clang-format off
using device_permute_scale_f16_instances = std::tuple<
    DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 4, 1, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 4, 8, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 4, 4, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 4, 2, ck::Sequence<1>, ck::Sequence<1>>
    >;

using device_permute_scale_f32_instances = std::tuple<
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 4, 1, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 4, 8, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 4, 4, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 4, 2, ck::Sequence<1>, ck::Sequence<1>>
    >;
// clang-format on

void add_device_permute_scale_f16_instances(
    std::vector<std::unique_ptr<
        DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 4>>>& instances)
{
    add_device_operation_instances(instances, device_permute_scale_f16_instances{});
}

void add_device_permute_scale_f32_instances(
    std::vector<std::unique_ptr<
        DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 4>>>& instances)
{
    add_device_operation_instances(instances, device_permute_scale_f32_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
#include "ck/utility/data_type.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Pass    = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale   = ck::tensor_operation::element_wise::Scale;

// clang-format off
template <index_t NDims>
using device_permute_scale_f16_instances = std::tuple<
    DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>>
    >;

template <index_t NDims>
using device_permute_scale_f32_instances = std::tuple<
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>,
    DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>>
    >;
// clang-format on

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
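The instance lists above are now templated on the tensor rank, so they are meant to be stamped out once per supported NDims. The following is a hedged sketch, not part of the diff, of how a rank-specific registration helper such as add_device_permute_scale_3d_f16_instances (referenced by the factory earlier in this commit) could consume device_permute_scale_f16_instances<NDims>; the exact helper signature and the include are assumptions.

// Hypothetical rank-specific registration helper; assumes
// add_device_operation_instance.hpp provides add_device_operation_instances().
#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

void add_device_permute_scale_3d_f16_instances(
    std::vector<std::unique_ptr<
        DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 3>>>& instances)
{
    // Instantiate the templated tuple above for rank 3 and register every entry.
    add_device_operation_instances(instances, device_permute_scale_f16_instances<3>{});
}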
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
View file @ 6d9a07d7
set(GEMM_SPLITK_INSTANCES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp
device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
)

list(APPEND GEMM_SPLITK_INSTANCES
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp
device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
)

add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp
0 → 100644
View file @ 6d9a07d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F8  = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault    = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding  = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmKPadding   = ck::tensor_operation::device::GemmSpecialization::KPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
          ck::PipelineVersion PipVer,
          ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances = std::tuple<
    // clang-format off
    //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
    //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
    //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
    //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGemmXdlSplitKCShuffle<F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16,  32, 8, 16, 16, 16, 1, 1, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>,
    DeviceGemmXdlSplitKCShuffle<F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16,  64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>,
    DeviceGemmXdlSplitKCShuffle<F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 16, 16, 16, 1, 4, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>
    // clang-format on
    >;

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(
    std::vector<std::unique_ptr<DeviceGemmSplitK<Row,
                                                 Col,
                                                 Row,
                                                 F16,
                                                 F8,
                                                 F16,
                                                 PassThrough,
                                                 PassThrough,
                                                 PassThrough>>>& instances)
{
    // default
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmDefault, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmDefault, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmDefault, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
    // MNKPadding
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmMNKPadding, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmMNKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmMNKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
    // KPadding
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmKPadding, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmKPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
    // MNPadding
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmMNPadding, ck::PipelineVersion::v2, ck::LoopScheduler::Default>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmMNPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
            GemmMNPadding, ck::PipelineVersion::v1, ck::LoopScheduler::Default>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
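As a usage illustration, not part of the commit, the sketch below shows how a client would typically reach the instances registered above through the split-K GEMM factory header touched in this changeset. It assumes the factory specialization for the Row/Col/Row f16 x f8 -> f16 problem forwards to add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances() among the other registration functions; the DeviceGemmSplitK template arguments match the signature of the function above.

// Hedged enumeration sketch for the new f16 x f8 -> f16 split-K instances.
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"

int main()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Problem type mirroring the DeviceGemmSplitK base class used by the
    // registration function above: A is row-major f16, B is column-major f8, C is row-major f16.
    using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<Row,
                                                                    Col,
                                                                    Row,
                                                                    ck::half_t,
                                                                    ck::f8_t,
                                                                    ck::half_t,
                                                                    PassThrough,
                                                                    PassThrough,
                                                                    PassThrough>;

    // Collect every registered instance, including the new kpb128 configurations.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " split-K GEMM instances" << std::endl;
    for(const auto& op_ptr : op_ptrs)
    {
        std::cout << op_ptr->GetTypeString() << std::endl;
    }
    return 0;
}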