Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ea5be216
Commit
ea5be216
authored
Aug 23, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
e2eb0418
25935b57
Changes
168
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
524 additions
and
335 deletions
+524
-335
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
...emm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
+33
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
+6
-9
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
+6
-9
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+6
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+6
-6
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+18
-4
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
.../tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
+79
-138
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
.../library/tensor_operation_instance/gpu/gemm_universal.hpp
+56
-129
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
...fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
+37
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp
...on_instance/gpu/grouped_convolution_forward_convscale.hpp
+83
-3
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp
...stance/gpu/grouped_convolution_forward_convscale_relu.hpp
+83
-1
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp
...k/library/tensor_operation_instance/gpu/permute_scale.hpp
+13
-0
library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp
...ance/gpu/permute_scale/device_permute_scale_instances.hpp
+52
-6
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp
...uce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp
+18
-9
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+7
-4
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...ary/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt
...ensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt
+5
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
...ration_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
+3
-4
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
...device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
+1
-2
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
...ultiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
+11
-11
No files found.
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp
0 → 100644
View file @
ea5be216
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmARegBRegCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
BlockGemmARegBRegCRegV1DefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
{
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
{},
4
,
1
);
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
{
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
{},
4
,
1
);
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
View file @
ea5be216
...
@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1
...
@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
"wrong!"
);
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr
index_t
MPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr
index_t
NPerBlock
=
BBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr
index_t
KPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
1
>
{}];
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
//
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
//
KPerBlock == BlockGemmShape::kK,
KPerBlock
==
BlockGemmShape
::
kK
,
//
"wrong!");
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
View file @
ea5be216
...
@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
...
@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
"wrong!"
);
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr
index_t
MPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr
index_t
NPerBlock
=
BBlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr
index_t
KPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}];
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
//
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
//
KPerBlock == BlockGemmShape::kK,
KPerBlock
==
BlockGemmShape
::
kK
,
//
"wrong!");
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
...
...
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
ea5be216
...
@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
...
@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
using
WarpGemmMfmaF16F16F32M16N16K32
=
using
WarpGemmMfmaF16F16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
1
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
...
@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
...
@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
2
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
1
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
ea5be216
...
@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
...
@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
static_for
<
0
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
static_for
<
0
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&
>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&
>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
});
}
}
...
@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
...
@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
// c = a * b
// c = a * b
auto
c_vec
=
Impl
{}(
auto
c_vec
=
Impl
{}(
reinterpret_cast
<
const
buf_a
>
(
a_vec
).
template
get_as
<
typename
Impl
::
AVecType
>()[
I0
],
reinterpret_cast
<
const
buf_a
&
>
(
a_vec
).
template
get_as
<
typename
Impl
::
AVecType
>()[
I0
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
).
template
get_as
<
typename
Impl
::
BVecType
>()[
I0
]);
reinterpret_cast
<
const
buf_b
&
>
(
b_vec
).
template
get_as
<
typename
Impl
::
BVecType
>()[
I0
]);
// c += a * b
// c += a * b
static_for
<
1
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
static_for
<
1
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&
>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&
>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
});
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
ea5be216
...
@@ -15,7 +15,8 @@ template <typename AType,
...
@@ -15,7 +15,8 @@ template <typename AType,
index_t
MPerWave
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
bool
TransposeC
,
bool
SwizzleA
=
false
>
struct
WarpGemmMfmaDispatcher
;
struct
WarpGemmMfmaDispatcher
;
// clang-format off
// clang-format off
...
@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
...
@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
// bf16
// bf16
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
...
@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
...
@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
// fp8
// fp8
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
...
@@ -58,8 +65,15 @@ template <typename AType,
...
@@ -58,8 +65,15 @@ template <typename AType,
index_t
MPerWave
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
bool
TransposeC
,
using
WarpGemmMfmaDispatcher
=
typename
impl
::
bool
SwizzleA
=
false
>
WarpGemmMfmaDispatcher
<
AType
,
BType
,
CType
,
MPerWave
,
NPerWave
,
KPerWave
,
TransposeC
>::
Type
;
using
WarpGemmMfmaDispatcher
=
typename
impl
::
WarpGemmMfmaDispatcher
<
AType
,
BType
,
CType
,
MPerWave
,
NPerWave
,
KPerWave
,
TransposeC
,
SwizzleA
>::
Type
;
}
// namespace ck_tile
}
// namespace ck_tile
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
View file @
ea5be216
...
@@ -18,134 +18,82 @@ namespace device {
...
@@ -18,134 +18,82 @@ namespace device {
namespace
instance
{
namespace
instance
{
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
#endif
#endif
template
<
typename
ADataType
,
template
<
typename
ADataType
,
...
@@ -154,7 +102,7 @@ template <typename ADataType,
...
@@ -154,7 +102,7 @@ template <typename ADataType,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
SplitK
<
ALayout
,
ALayout
,
BLayout
,
BLayout
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
...
@@ -167,17 +115,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -167,17 +115,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>>
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>>
{
{
using
DeviceOp
=
DeviceGemmMultipleD
<
ALayout
,
using
DeviceOp
=
BLayout
,
DeviceGemmMultipleDSplitK
<
ALayout
,
Tuple
<
Row
,
Col
>
,
BLayout
,
CLayout
,
Tuple
<
Row
,
Col
>
,
ADataType
,
CLayout
,
BDataType
,
ADataType
,
Tuple
<
F32
,
F32
>
,
BDataType
,
CDataType
,
Tuple
<
F32
,
F32
>
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>
;
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
...
@@ -194,24 +143,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -194,24 +143,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs
);
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
}
}
#endif
#endif
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
View file @
ea5be216
...
@@ -77,16 +77,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
...
@@ -77,16 +77,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -97,11 +87,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance
...
@@ -97,11 +87,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -111,13 +96,8 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance
...
@@ -111,13 +96,8 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#endif
#if(defined(CK_ENABLE_FP16)
||
defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_FP16)
&&
defined(CK_ENABLE_FP8))
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances
(
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -177,16 +157,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
...
@@ -177,16 +157,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -196,12 +166,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
...
@@ -196,12 +166,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -212,10 +176,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
...
@@ -212,10 +176,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances
(
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -275,16 +235,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
...
@@ -275,16 +235,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -295,11 +245,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
...
@@ -295,11 +245,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -309,11 +254,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
...
@@ -309,11 +254,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
...
@@ -376,93 +316,98 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instanc
...
@@ -376,93 +316,98 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instanc
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_
comp_mnpadding
_instances
(
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_
mem_v1_default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_
comp_mn
kpadding_instances
(
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_
mem_v1_
kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v
1
_default_instances
(
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v
2
_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v
1
_kpadding_instances
(
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v
2
_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#endif
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_
bf16_bf16
_bf16_mk_
n
k_mn_
mem_v1_mn
kpadding_instances
(
void
add_device_gemm_xdl_universal_
f8_f8
_bf16_mk_k
n
_mn_
comp_
kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_
bf16_bf16
_bf16_mk_
n
k_mn_
mem_v2_default
_instances
(
void
add_device_gemm_xdl_universal_
f8_f8
_bf16_mk_k
n
_mn_
comp_nkpadding
_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_
bf16_bf16
_bf16_mk_
n
k_mn_mem_v
2_kpadding
_instances
(
void
add_device_gemm_xdl_universal_
f8_f8
_bf16_mk_k
n
_mn_mem_v
1_default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_
bf16_bf16
_bf16_mk_
n
k_mn_mem_v
2_mn
kpadding_instances
(
void
add_device_gemm_xdl_universal_
f8_f8
_bf16_mk_k
n
_mn_mem_v
1_
kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_
n
k_mn_
comp_kpadding
_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_k
n
_mn_
mem_v2_default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_
n
k_mn_
comp_mn
padding_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_k
n
_mn_
mem_v2_k
padding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_
n
k_mn_
comp_m
nkpadding_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_k
n
_mn_
mem_v2_
nkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_
mem_v1
_default_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_
comp
_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_
mem_v1
_kpadding_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_
comp
_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_
mnkpadding
_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_
default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v
2_default
_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v
1_kpadding
_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_
kpadding
_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_
default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_
mn
kpadding_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
...
@@ -532,28 +477,20 @@ struct DeviceOperationInstanceFactory<
...
@@ -532,28 +477,20 @@ struct DeviceOperationInstanceFactory<
{
{
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances
(
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
(
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances
(
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
(
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
}
}
#endif
#endif
#if(defined(CK_ENABLE_FP16)
||
defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_FP16)
&&
defined(CK_ENABLE_FP8))
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
is_same_v
<
CDataType
,
half_t
>
)
{
{
...
@@ -562,21 +499,14 @@ struct DeviceOperationInstanceFactory<
...
@@ -562,21 +499,14 @@ struct DeviceOperationInstanceFactory<
{
{
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
(
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
(
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
...
@@ -608,21 +538,14 @@ struct DeviceOperationInstanceFactory<
...
@@ -608,21 +538,14 @@ struct DeviceOperationInstanceFactory<
{
{
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
(
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
(
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
...
@@ -684,51 +607,55 @@ struct DeviceOperationInstanceFactory<
...
@@ -684,51 +607,55 @@ struct DeviceOperationInstanceFactory<
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
}
}
#endif
#endif
#if(defined(CK_ENABLE_BF16)
||
defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_BF16)
&&
defined(CK_ENABLE_FP8))
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
is_same_v
<
CDataType
,
bhalf_t
>
)
{
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
{
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
}
}
#endif
#endif
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
View file @
ea5be216
...
@@ -184,6 +184,43 @@ using device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances = std::tuple<
...
@@ -184,6 +184,43 @@ using device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
,
typename
OutElementOp
>
using
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef CK_ENABLE_FP8
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F8
,
F8
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
,
F8
>
#endif
// clang-format on
>
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp
View file @
ea5be216
...
@@ -8,9 +8,7 @@
...
@@ -8,9 +8,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -177,6 +175,88 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -177,6 +175,88 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
};
};
using
CombConvScale
=
ck
::
tensor_operation
::
element_wise
::
ScaleScalePass
;
#ifdef CK_ENABLE_FP8
void
add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
CombConvScale
,
F8
,
F8
>>>&
instances
);
#endif
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
DLayouts
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
DDataTypes
,
typename
OutDataType
,
typename
AComputeType
,
typename
BComputeType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
PassThrough
,
PassThrough
,
CombConvScale
,
AComputeType
,
BComputeType
>>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
PassThrough
,
PassThrough
,
CombConvScale
,
AComputeType
,
BComputeType
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWGK
>
)
{
#ifdef CK_ENABLE_FP8
if
constexpr
(
is_same_v
<
InDataType
,
f8_t
>
&&
is_same_v
<
WeiDataType
,
f8_t
>
&&
is_same_v
<
OutDataType
,
F32
>
&&
is_same_v
<
AComputeType
,
f8_t
>
&&
is_same_v
<
BComputeType
,
f8_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
op_ptrs
);
}
#endif
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp
View file @
ea5be216
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/
unary
_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/
combined
_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -99,6 +99,88 @@ struct DeviceOperationInstanceFactory<
...
@@ -99,6 +99,88 @@ struct DeviceOperationInstanceFactory<
}
}
};
};
using
CombConvScaleRelu
=
ck
::
tensor_operation
::
element_wise
::
ScaleScaleRelu
;
#ifdef CK_ENABLE_FP8
void
add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
CombConvScaleRelu
,
F8
,
F8
>>>&
instances
);
#endif
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
DLayouts
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
DDataTypes
,
typename
OutDataType
,
typename
AComputeType
,
typename
BComputeType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
PassThrough
,
PassThrough
,
CombConvScaleRelu
,
AComputeType
,
BComputeType
>>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
PassThrough
,
PassThrough
,
CombConvScaleRelu
,
AComputeType
,
BComputeType
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWGK
>
)
{
#ifdef CK_ENABLE_FP8
if
constexpr
(
is_same_v
<
InDataType
,
f8_t
>
&&
is_same_v
<
WeiDataType
,
f8_t
>
&&
is_same_v
<
OutDataType
,
F32
>
&&
is_same_v
<
AComputeType
,
f8_t
>
&&
is_same_v
<
BComputeType
,
f8_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
op_ptrs
);
}
#endif
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp
View file @
ea5be216
...
@@ -70,6 +70,12 @@ void add_device_permute_scale_6d_f32_instances(
...
@@ -70,6 +70,12 @@ void add_device_permute_scale_6d_f32_instances(
DeviceElementwise
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
element_wise
::
Scale
,
6
>>>&
);
DeviceElementwise
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
element_wise
::
Scale
,
6
>>>&
);
#endif
#endif
#ifdef CK_ENABLE_FP8
void
add_device_permute_scale_6d_f32_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceElementwise
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
element_wise
::
Scale
,
6
>>>&
);
#endif
template
<
typename
InDataTypeTuple
,
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
ElementwiseOperation
,
typename
ElementwiseOperation
,
...
@@ -184,6 +190,13 @@ struct DeviceOperationInstanceFactory<
...
@@ -184,6 +190,13 @@ struct DeviceOperationInstanceFactory<
{
{
add_device_permute_scale_6d_f16_instances
(
op_ptrs
);
add_device_permute_scale_6d_f16_instances
(
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_FP8
if
constexpr
(
is_same_v
<
InDataTypeTuple
,
ck
::
Tuple
<
F32
>>
&&
is_same_v
<
OutDataTypeTuple
,
ck
::
Tuple
<
F8
>>
)
{
add_device_permute_scale_6d_f32_f8_instances
(
op_ptrs
);
}
#endif
#endif
}
}
return
op_ptrs
;
return
op_ptrs
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp
View file @
ea5be216
...
@@ -10,6 +10,7 @@ namespace tensor_operation {
...
@@ -10,6 +10,7 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
...
@@ -46,7 +47,7 @@ using device_permute_scale_f16_instances =
...
@@ -46,7 +47,7 @@ using device_permute_scale_f16_instances =
#if 0
#if 0
// Disabled instances to improve compilation time
// Disabled instances to improve compilation time
// They listed here to show other possible combinations of parameters
// They listed here to show other possible combinations of parameters
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
...
@@ -57,7 +58,7 @@ using device_permute_scale_f16_instances =
...
@@ -57,7 +58,7 @@ using device_permute_scale_f16_instances =
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
...
@@ -97,7 +98,7 @@ using device_permute_scale_f16_instances =
...
@@ -97,7 +98,7 @@ using device_permute_scale_f16_instances =
DeviceElementwiseImpl
<
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<
F16
>
,
ElementwiseOp
,
NDims
,
64
,
64
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<
F16
>
,
ElementwiseOp
,
NDims
,
64
,
64
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<
F16
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<
F16
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<
F16
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
DeviceElementwiseImpl
<
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<
F16
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
>
;
>
;
template
<
index_t
NDims
,
template
<
index_t
NDims
,
...
@@ -131,7 +132,7 @@ using device_permute_scale_f32_instances = std::tuple<
...
@@ -131,7 +132,7 @@ using device_permute_scale_f32_instances = std::tuple<
#if 0
#if 0
// Disabled instances to improve compilation time
// Disabled instances to improve compilation time
// They listed here to show other possible combinations of parameters
// They listed here to show other possible combinations of parameters
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
...
@@ -142,7 +143,7 @@ using device_permute_scale_f32_instances = std::tuple<
...
@@ -142,7 +143,7 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
...
@@ -168,7 +169,7 @@ using device_permute_scale_f32_instances = std::tuple<
...
@@ -168,7 +169,7 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
#endif
#endif
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
256
,
64
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
256
,
64
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
256
,
128
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
256
,
128
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
...
@@ -183,6 +184,51 @@ using device_permute_scale_f32_instances = std::tuple<
...
@@ -183,6 +184,51 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
>
;
>
;
#ifdef CK_ENABLE_FP8
template
<
index_t
NDims
,
typename
ElementwiseOp
>
using
device_permute_scale_f32_f8_instances
=
std
::
tuple
<
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
64
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
128
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
32
,
128
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
64
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
32
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
16
,
128
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
128
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
32
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
16
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
64
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
128
,
128
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
256
,
64
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
64
,
256
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
128
,
64
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
64
,
128
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
32
,
256
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
256
,
32
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
64
,
64
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
32
,
128
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
128
,
32
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
64
,
32
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
64
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
64
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
128
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
32
,
128
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
64
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
32
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
16
,
128
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
128
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
32
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
16
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
64
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
>
;
#endif
// clang-format on
// clang-format on
}
// namespace instance
}
// namespace instance
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp
View file @
ea5be216
...
@@ -14,15 +14,24 @@ namespace device {
...
@@ -14,15 +14,24 @@ namespace device {
namespace
instance
{
namespace
instance
{
// clang-format off
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
6
,
6
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
6
,
6
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
5
,
5
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
5
,
5
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
6
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
6
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
5
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
5
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
3
,
3
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
3
,
3
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
2
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
2
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
1
,
1
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
1
,
1
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>>&
);
// clang-format on
// clang-format on
}
// namespace instance
}
// namespace instance
...
...
library/include/ck/library/utility/check_err.hpp
View file @
ea5be216
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -146,7 +146,7 @@ check_err(const Range& out,
...
@@ -146,7 +146,7 @@ check_err(const Range& out,
bool
res
{
true
};
bool
res
{
true
};
int
err_count
=
0
;
int
err_count
=
0
;
double
err
=
0
;
double
err
=
0
;
double
max_err
=
std
::
n
umeric
_l
imits
<
ranges
::
range_value_t
<
Range
>>::
m
in
();
double
max_err
=
N
umeric
L
imits
<
ranges
::
range_value_t
<
Range
>>::
M
in
();
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
...
@@ -178,7 +178,9 @@ check_err(const Range& out,
...
@@ -178,7 +178,9 @@ check_err(const Range& out,
template
<
typename
Range
,
typename
RefRange
>
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_integral_v
<
ranges
::
range_value_t
<
Range
>>
&&
std
::
is_integral_v
<
ranges
::
range_value_t
<
Range
>>
&&
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bhalf_t
>
)
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bhalf_t
>
&&
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f8_t
>
&&
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bf8_t
>
)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
||
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
int4_t
>
||
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
int4_t
>
#endif
#endif
...
@@ -270,7 +272,8 @@ check_err(const Range& out,
...
@@ -270,7 +272,8 @@ check_err(const Range& out,
}
}
if
(
!
res
)
if
(
!
res
)
{
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
" number of errors: "
<<
err_count
<<
std
::
endl
;
}
}
return
res
;
return
res
;
}
}
...
...
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
ea5be216
...
@@ -111,6 +111,7 @@ list(APPEND GEMM_INSTANCES
...
@@ -111,6 +111,7 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp
)
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES
list
(
APPEND GEMM_INSTANCES
device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt
View file @
ea5be216
...
@@ -11,4 +11,9 @@ list(APPEND GEMM_AB_SCALE_INSTANCES
...
@@ -11,4 +11,9 @@ list(APPEND GEMM_AB_SCALE_INSTANCES
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
)
)
set_source_files_properties
(
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
add_instance_library
(
device_gemm_ab_scale_instance
${
GEMM_AB_SCALE_INSTANCES
}
)
add_instance_library
(
device_gemm_ab_scale_instance
${
GEMM_AB_SCALE_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
View file @
ea5be216
...
@@ -4,14 +4,13 @@ set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)
...
@@ -4,14 +4,13 @@ set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)
list
(
APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
list
(
APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
)
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
add_instance_library
(
device_gemm_multiply_multiply_instance
${
GEMM_MULTIPLY_MULTIPLY_INSTANCES
}
)
add_instance_library
(
device_gemm_multiply_multiply_instance
${
GEMM_MULTIPLY_MULTIPLY_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
View file @
ea5be216
...
@@ -46,8 +46,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std
...
@@ -46,8 +46,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
128
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
224
,
256
,
128
,
16
,
16
,
16
,
16
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
224
,
256
,
128
,
16
,
16
,
16
,
16
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
224
,
128
,
16
,
16
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
224
,
128
,
16
,
16
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
View file @
ea5be216
...
@@ -9,17 +9,17 @@ namespace device {
...
@@ -9,17 +9,17 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment