Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2b8a9941
Commit
2b8a9941
authored
Oct 16, 2023
by
Artur Wojcik
Browse files
Merge branch 'develop' into uif2-initial
parents
ce9d7c8d
707ad002
Changes
40
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
844 additions
and
944 deletions
+844
-944
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
+55
-89
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+210
-628
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+9
-0
include/ck/utility/sequence.hpp
include/ck/utility/sequence.hpp
+0
-11
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+21
-21
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
.../ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
+52
-4
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
.../tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
+5
-1
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
...emm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
+62
-0
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
...emm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
+62
-0
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
...emm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
+83
-0
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
...emm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
+79
-0
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
+34
-34
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
...grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
+21
-35
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
...mm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
+51
-35
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
...grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
+17
-25
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
...mm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+24
-25
profiler/include/profiler/profile_gemm_splitk_impl.hpp
profiler/include/profiler/profile_gemm_splitk_impl.hpp
+6
-3
profiler/src/profile_gemm_splitk.cpp
profiler/src/profile_gemm_splitk.cpp
+41
-19
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+1
-3
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
+11
-11
No files found.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
View file @
2b8a9941
...
...
@@ -132,9 +132,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
Number
<
num
>
{});
}
template
<
typename
T
>
using
has_vec_len
=
decltype
(
std
::
declval
<
T
&>
().
vec_len
);
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template
<
typename
SrcBuffers
,
...
...
@@ -159,94 +156,63 @@ struct ThreadwiseTensorSliceTransfer_v7r2
is_src_valid
);
});
if
constexpr
(
is_detected
<
has_vec_len
,
decltype
(
element_op_
)
>::
value
)
{
constexpr
auto
elem_op_vec_len
=
decltype
(
element_op_
)
::
vec_len
;
static_assert
(
is_same
<
remove_cvref_t
<
decltype
(
elem_op_vec_len
)
>
,
index_t
>::
value
,
"vec_len in element_op_ type is not index_t"
);
constexpr
auto
get_elem_op_vec_len
=
[]()
{
if
constexpr
(
is_detected
<
is_pack8_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack8_invocable
)
return
math
::
min
(
8
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack4_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack4_invocable
)
return
math
::
min
(
4
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack2_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack2_invocable
)
return
math
::
min
(
2
,
SrcScalarPerVector
);
}
return
1
;
};
constexpr
index_t
elem_op_vec_len
=
get_elem_op_vec_len
();
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
return
src_vectors
[
iSrc
].
template
AsType
<
elem_op_vec_t
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
return
dst_vectors
(
iDst
).
template
AsType
<
elem_op_vec_t
>()(
i
);
},
Number
<
nDst
>
{});
static_assert
(
elem_op_vec_len
==
1
||
elem_op_vec_len
==
2
||
elem_op_vec_len
==
4
||
elem_op_vec_len
==
8
,
"vec_len in element_op_ must be 1, 2, 4, 8"
);
static_assert
(
SrcScalarPerVector
%
elem_op_vec_len
==
0
,
"vec_len in element_op_ cannot be divided by SrcScalarPerVector!"
);
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
return
src_vectors
[
iSrc
].
template
AsType
<
elem_op_vec_t
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
return
dst_vectors
(
iDst
).
template
AsType
<
elem_op_vec_t
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
}
else
{
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
return
src_vectors
[
iSrc
].
template
AsType
<
SrcData
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
return
dst_vectors
(
iDst
).
template
AsType
<
DstData
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
}
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
dst_vectors_tuple_
(
iAccess
)
=
dst_vectors
;
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
2b8a9941
...
...
@@ -299,584 +299,255 @@ enum struct AmdBufferCoherenceEnum
GLC_SLC
=
3
,
};
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
typename
vector_type
<
int8_t
,
N
>::
type
amd_buffer_load_impl_raw
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
if
constexpr
(
N
==
1
)
{
// use fp32 load to mimic fp64 load
if
constexpr
(
N
==
1
)
{
const
float2_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double
>
(
tmp
);
}
else
if
constexpr
(
N
==
2
)
{
const
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double2_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
const
float4_t
f32_0
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
const
float4_t
f32_1
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
vector_type
<
double
,
4
>
tmp
;
tmp
.
AsType
<
double2_t
>
()(
Number
<
0
>
{})
=
bit_cast
<
double2_t
>
(
f32_0
);
tmp
.
AsType
<
double2_t
>
()(
Number
<
1
>
{})
=
bit_cast
<
double2_t
>
(
f32_1
);
return
tmp
.
AsType
<
double4_t
>
()(
Number
<
0
>
{});
}
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
else
if
constexpr
(
N
==
2
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
;
tmp
.
AsType
<
float4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
float4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
float8_t
>
()(
Number
<
0
>
{});
}
return
bit_cast
<
int8x2_t
>
(
tmp
);
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
else
if
constexpr
(
N
==
4
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x4
(
src_wave_buffer_resource
,
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
// use fp32 load to mimic fp16 load
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
half8_t
>
(
tmp
);
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
bhalf8_t
>
(
tmp
);
}
return
bit_cast
<
int8x4_t
>
(
tmp
);
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
else
if
constexpr
(
N
==
8
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
int32_t
,
8
>
tmp
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int32x8_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x2_t
>
(
tmp
);
#endif
}
else
if
constexpr
(
N
==
4
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
return
bit_cast
<
int8x8_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
16
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x16_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
32
)
{
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
vector_type
<
int32_t
,
8
>
tmp
;
return
bit_cast
<
int8x4_t
>
(
tmp
);
#endif
}
else
if
constexpr
(
N
==
8
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type
<
int8_t
,
8
>
tmp
;
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int8x8_t
>
()(
Number
<
0
>
{});
#else
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
tmp0
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
tmp1
;
return
bit_cast
<
int8x8_t
>
(
tmp
);
#endif
}
else
if
constexpr
(
N
==
16
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type
<
int8_t
,
16
>
tmp
;
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
2
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
8
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
3
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
12
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int8x16_t
>
()(
Number
<
0
>
{});
#else
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x32_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
64
)
{
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp2
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
8
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp3
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
12
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x16_t
>
(
tmp
);
#endif
}
vector_type
<
int32_t
,
16
>
tmp
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
tmp0
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
tmp1
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
2
>
{})
=
tmp2
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
3
>
{})
=
tmp3
;
return
bit_cast
<
int8x64_t
>
(
tmp
);
}
}
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
auto
raw_data
=
amd_buffer_load_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
);
return
bit_cast
<
r_t
>
(
raw_data
);
}
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl_raw
(
const
typename
vector_type
<
int8_t
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
"wrong! not implemented"
);
if
constexpr
(
N
==
1
)
{
// use fp32 store to mimic fp64 store
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
bit_cast
<
float2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
llvm_amdgcn_raw_buffer_store_i8
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
else
if
constexpr
(
N
==
2
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp32
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
}
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
else
if
constexpr
(
N
==
4
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
#if 0
vector_type<half_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
llvm_amdgcn_raw_buffer_store_i32
(
bit_cast
<
int32_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
else
if
constexpr
(
N
==
8
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
bhalf_t
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
bhalf_t
),
static_cast
<
index_t
>
(
coherence
));
}
llvm_amdgcn_raw_buffer_store_i32x2
(
bit_cast
<
int32x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
else
if
constexpr
(
N
==
16
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i32
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i32x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
llvm_amdgcn_raw_buffer_store_i32x4
(
bit_cast
<
int32x4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
else
if
constexpr
(
N
==
32
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i8
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
else
if
constexpr
(
N
==
4
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_i32
(
bit_cast
<
int32_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
else
if
constexpr
(
N
==
8
)
{
llvm_amdgcn_raw_buffer_store_i32x2
(
bit_cast
<
int32x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
16
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
bit_cast
<
int32x4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
vector_type
<
int32_t
,
8
>
tmp
{
bit_cast
<
int32x8_t
>
(
src_thread_data
)};
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
64
)
{
vector_type
<
int32_t
,
16
>
tmp
{
bit_cast
<
int32x16_t
>
(
src_thread_data
)};
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
8
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
12
,
static_cast
<
index_t
>
(
coherence
));
}
}
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
int8_t
,
sizeof
(
T
)
*
N
>::
type
;
amd_buffer_store_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
bit_cast
<
r_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
);
}
template
<
typename
T
,
index_t
N
>
...
...
@@ -1127,54 +798,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
return
bit_cast
<
vector_t
>
(
tmp
);
}
else
{
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#else
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
bit_cast
<
vector_t
>
(
tmp
)
:
vector_t
(
0
);
}
else
{
#endif
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#endif
}
...
...
@@ -1232,62 +863,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#else
if
(
dst_thread_element_valid
)
{
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
#endif
}
...
...
include/ck/utility/is_detected.hpp
View file @
2b8a9941
...
...
@@ -31,4 +31,13 @@ struct nonesuch
template
<
template
<
class
...
>
class
Op
,
class
...
Args
>
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
template
<
typename
T
>
using
is_pack2_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack2_invocable
);
template
<
typename
T
>
using
is_pack4_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack4_invocable
);
template
<
typename
T
>
using
is_pack8_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack8_invocable
);
}
// namespace ck
include/ck/utility/sequence.hpp
View file @
2b8a9941
...
...
@@ -897,14 +897,3 @@ template <index_t NSize, index_t I>
using
uniform_sequence_gen_t
=
typename
uniform_sequence_gen
<
NSize
,
I
>::
type
;
}
// namespace ck
template
<
ck
::
index_t
...
Is
>
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck
::
Sequence
<
Is
...
>
)
{
using
S
=
ck
::
Sequence
<
Is
...
>
;
os
<<
"{"
;
ck
::
static_for
<
0
,
S
::
Size
()
-
ck
::
Number
<
1
>
{},
1
>
{}(
[
&
](
auto
i
)
{
os
<<
S
::
At
(
i
).
value
<<
", "
;
});
os
<<
S
::
At
(
S
::
Size
()
-
ck
::
Number
<
1
>
{}).
value
<<
"}"
;
return
os
;
}
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
2b8a9941
...
...
@@ -14,27 +14,27 @@ namespace ck {
namespace
tensor_operation
{
namespace
host
{
//
/
//
/
@brief Reference implementation for forward convolution.
//
/
//
/
@paragraph
//
/
Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
//
/
Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
//
/
as long as dimensions in tensor descriptor is in GNCHW order
//
/
//
/
@tparam InDataType Input tensor data type.
//
/
@tparam WeiDataType Weights tensor data type.
//
/
@tparam OutDataType Output tensor data type.
//
/
@tparam InElementwiseOperation Functor for input tensor elementwise
//
/
operation.
//
/
@tparam WeiElementwiseOperation Functor for weights tensor elementwise
//
/
operation.
//
/
@tparam NDimSpatial Number of spatial dimensions.
//
/
//
/
input descriptor in [G, N, C, Do, Ho, Wo] order
//
/
weight descriptor in [G, K, C, Z, Y, X] order
//
/
output descriptor in [G, N, K, Di, Hi, Wi] order
//
/
phyiscal layout is irrelavent
//
// @brief Reference implementation for forward convolution.
//
// @paragraph
// Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as dimensions in tensor descriptor is in GNCHW order
//
// @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type.
// @tparam OutDataType Output tensor data type.
// @tparam InElementwiseOperation Functor for input tensor elementwise
// operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation.
// @tparam NDimSpatial Number of spatial dimensions.
//
// input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
View file @
2b8a9941
...
...
@@ -98,6 +98,26 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
);
#endif
template
<
typename
ADataType
,
...
...
@@ -105,7 +125,8 @@ template <typename ADataType,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
,
typename
ComputeType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmSplitK
<
ALayout
,
BLayout
,
...
...
@@ -115,7 +136,8 @@ struct DeviceOperationInstanceFactory<
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ComputeType
>>
{
using
DeviceOp
=
DeviceGemmSplitK
<
ALayout
,
BLayout
,
...
...
@@ -125,7 +147,8 @@ struct DeviceOperationInstanceFactory<
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ComputeType
>
;
static
auto
GetInstances
()
{
...
...
@@ -158,7 +181,7 @@ struct DeviceOperationInstanceFactory<
#endif
#ifdef CK_ENABLE_FP16
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
is_same_v
<
CDataType
,
half_t
>
&&
is_same_v
<
ComputeType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
...
...
@@ -231,6 +254,31 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
&&
is_same_v
<
ComputeType
,
f8_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
View file @
2b8a9941
...
...
@@ -15,6 +15,10 @@ list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_in
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
)
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
)
add_instance_library
(
device_gemm_splitk_instance
${
GEMM_SPLITK_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
0 → 100644
View file @
2b8a9941
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
// clang-format on
>
;
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
0 → 100644
View file @
2b8a9941
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
// clang-format on
>
;
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
0 → 100644
View file @
2b8a9941
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmMNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_generic_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
1
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
F8
>
// clang-format on
>
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
64
,
192
,
4
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
48
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
192
,
64
,
4
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
192
,
4
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
192
,
32
,
4
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
64
,
32
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
// clang-format on
>
;
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_generic_instances
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
0 → 100644
View file @
2b8a9941
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
MNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_generic_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
1
,
8
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
2
,
F8
>
// clang-format on
>
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
// clang-format on
>
;
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_generic_instances
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
2b8a9941
...
...
@@ -33,26 +33,26 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// a[m, k] * b[k, n] = e[m, n]
using
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
2
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
//###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
2
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
@@ -60,20 +60,20 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
using
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
//###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemm_Xdl
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
2b8a9941
...
...
@@ -33,41 +33,27 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// a[m, k] * b[k, n] = e[m, n]
using
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| PipelineVersion|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Pipline v1
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
// Pipline v2
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Currently AK1 must equal BK1 !
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
View file @
2b8a9941
...
...
@@ -31,41 +31,57 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
using
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| PipelineVersion|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Pipline v1
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
//DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
//DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
//DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
// Pipline v2
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
//DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>,
//DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>,
//DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Currently AK1 must equal BK1 !
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
192
,
64
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
48
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
192
,
64
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
48
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
2b8a9941
...
...
@@ -32,31 +32,23 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// a[m, k] * b[n, k] = e[m, n]
using
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| PipelineVersion|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Pipeline v1
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
64
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
// Pipeline v2
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
64
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v2
>
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
View file @
2b8a9941
...
...
@@ -31,31 +31,30 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
using
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| PipelineVersion|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Pipeline v1
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
64
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
//DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, ck::PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v1
>
,
// Pipeline v2
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
64
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
//DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, ck::PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
ck
::
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
8
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
ck
::
PipelineVersion
::
v2
>
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
192
,
64
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
48
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
256
,
32
,
8
,
8
,
32
,
32
,
1
,
4
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// clang-format on
>
;
...
...
profiler/include/profiler/profile_gemm_splitk_impl.hpp
View file @
2b8a9941
...
...
@@ -30,7 +30,8 @@ template <typename ADataType,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
,
typename
ComputeType
=
CDataType
>
bool
profile_gemm_splitk_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
...
...
@@ -103,7 +104,8 @@ bool profile_gemm_splitk_impl(int do_verification,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
,
ComputeType
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
@@ -120,7 +122,8 @@ bool profile_gemm_splitk_impl(int do_verification,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
,
ComputeType
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
profiler/src/profile_gemm_splitk.cpp
View file @
2b8a9941
...
...
@@ -25,6 +25,7 @@ enum struct GemmDataType
INT8_INT8_INT8
,
// 3
F8_F16_F16
,
// 4
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
};
#define OP_NAME "gemm_splitk"
...
...
@@ -35,7 +36,8 @@ int profile_gemm_splitk(int argc, char* argv[])
if
(
argc
!=
15
)
{
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
"comp f8)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
...
@@ -80,7 +82,8 @@ int profile_gemm_splitk(int argc, char* argv[])
auto
c_type
,
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
auto
c_layout
,
auto
compute_type
)
{
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
AccDataType
=
decltype
(
acc_type
);
...
...
@@ -90,6 +93,8 @@ int profile_gemm_splitk(int argc, char* argv[])
using
BLayout
=
decltype
(
b_layout
);
using
CLayout
=
decltype
(
c_layout
);
using
ComputeType
=
decltype
(
compute_type
);
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
N
:
M
;
...
...
@@ -100,7 +105,8 @@ int profile_gemm_splitk(int argc, char* argv[])
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
CLayout
,
ComputeType
>
(
do_verification
,
init_method
,
do_log
,
...
...
@@ -118,68 +124,84 @@ int profile_gemm_splitk(int argc, char* argv[])
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Row
{},
Row
{}
,
F32
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Col
{},
Row
{}
,
F32
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Row
{},
Row
{}
,
F32
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Col
{},
Row
{}
,
F32
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{}
,
F16
{}
);
}
#if defined CK_ENABLE_FP8
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F8
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F8
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F8
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F8
{});
}
#endif
else
...
...
script/cmake-ck-dev.sh
View file @
2b8a9941
...
...
@@ -12,9 +12,7 @@ cmake
-save-temps=
$PWD
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
"gfx90
a
"
\
-D
GPU_TARGETS
=
"gfx90
8;gfx90a;gfx940
"
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
#-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
View file @
2b8a9941
...
...
@@ -4,7 +4,7 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
constexpr
int
N
=
768
;
constexpr
int
K
=
1088
;
constexpr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
@@ -17,9 +17,9 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
TEST_P
(
RRR_F16_F16_F16
,
SmallCases
)
{
const
std
::
vector
<
int
>
Ms
{
2
,
3
,
4
,
5
};
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
constexpr
int
N
=
768
;
constexpr
int
K
=
1088
;
constexpr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
@@ -34,7 +34,7 @@ TEST_P(RRR_F16_F16_F16, MidCases)
{
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
153
,
139
,
204
};
constexpr
int
N
=
768
;
constexpr
int
K
=
1088
;
constexpr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
@@ -49,7 +49,7 @@ TEST_P(RRR_F16_F16_F16, Regular)
{
const
std
::
vector
<
int
>
Ms
{
64
,
128
,
256
};
constexpr
int
N
=
768
;
constexpr
int
K
=
64
0
;
constexpr
int
K
=
32
0
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
@@ -79,7 +79,7 @@ TEST_P(RCR_F16_F16_F16, TinyCases)
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
constexpr
int
N
=
768
;
constexpr
int
K
=
1088
;
constexpr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
@@ -91,9 +91,9 @@ TEST_P(RCR_F16_F16_F16, TinyCases)
TEST_P
(
RCR_F16_F16_F16
,
SmallCases
)
{
const
std
::
vector
<
int
>
Ms
{
2
,
3
,
4
,
5
};
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
constexpr
int
N
=
768
;
constexpr
int
K
=
1088
;
constexpr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
@@ -123,7 +123,7 @@ TEST_P(RCR_F16_F16_F16, Regular)
{
const
std
::
vector
<
int
>
Ms
{
32
,
64
,
128
,
256
};
constexpr
int
N
=
768
;
constexpr
int
K
=
64
0
;
constexpr
int
K
=
32
0
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
@@ -151,9 +151,9 @@ TEST_P(RCR_F16_F16_F16, MNKPadded)
TEST_P
(
RRR_F16_F16_F16_LargeK
,
TestLargeKBatch
)
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
const
std
::
vector
<
int
>
Ms
{
188
,
210
};
constexpr
int
N
=
768
;
constexpr
int
K
=
8192
;
constexpr
int
K
=
4096
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment