Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0eb75e21
Commit
0eb75e21
authored
Aug 17, 2024
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into ck_tile/moe
parents
1b4b640b
c8b6b642
Changes
200
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
456 additions
and
503 deletions
+456
-503
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
+6
-9
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+6
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+6
-6
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+18
-4
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
...erence_tensor_operation/cpu/reference_column_to_image.hpp
+59
-57
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...eference_tensor_operation/cpu/reference_conv_bwd_data.hpp
+12
-12
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
...erence_tensor_operation/cpu/reference_conv_bwd_weight.hpp
+12
-12
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+13
-13
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp
...erence_tensor_operation/cpu/reference_image_to_column.hpp
+50
-49
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
.../tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
+79
-138
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
.../library/tensor_operation_instance/gpu/gemm_universal.hpp
+56
-129
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp
...fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp
+20
-23
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
...or_operation_instance/gpu/grouped_convolution_forward.hpp
+7
-7
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
...ance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
+6
-6
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+4
-2
library/include/ck/library/utility/convolution_parameter.hpp
library/include/ck/library/utility/convolution_parameter.hpp
+30
-18
library/include/ck/library/utility/host_tensor.hpp
library/include/ck/library/utility/host_tensor.hpp
+18
-3
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+48
-15
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...ary/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt
...ensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt
+5
-0
No files found.
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
View file @
0eb75e21
...
...
@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
constexpr
index_t
MPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}];
//
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
//
KPerBlock == BlockGemmShape::kK,
//
"wrong!");
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
...
...
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
0eb75e21
...
...
@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
using
WarpGemmMfmaF16F16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
1
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
...
...
@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
1
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
0eb75e21
...
...
@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
static_for
<
0
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&
>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&
>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
}
...
...
@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
// c = a * b
auto
c_vec
=
Impl
{}(
reinterpret_cast
<
const
buf_a
>
(
a_vec
).
template
get_as
<
typename
Impl
::
AVecType
>()[
I0
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
).
template
get_as
<
typename
Impl
::
BVecType
>()[
I0
]);
reinterpret_cast
<
const
buf_a
&
>
(
a_vec
).
template
get_as
<
typename
Impl
::
AVecType
>()[
I0
],
reinterpret_cast
<
const
buf_b
&
>
(
b_vec
).
template
get_as
<
typename
Impl
::
BVecType
>()[
I0
]);
// c += a * b
static_for
<
1
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&
>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&
>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
0eb75e21
...
...
@@ -15,7 +15,8 @@ template <typename AType,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
bool
TransposeC
,
bool
SwizzleA
=
false
>
struct
WarpGemmMfmaDispatcher
;
// clang-format off
...
...
@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
// bf16
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
...
...
@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
// fp8
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
...
...
@@ -58,8 +65,15 @@ template <typename AType,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
using
WarpGemmMfmaDispatcher
=
typename
impl
::
WarpGemmMfmaDispatcher
<
AType
,
BType
,
CType
,
MPerWave
,
NPerWave
,
KPerWave
,
TransposeC
>::
Type
;
bool
TransposeC
,
bool
SwizzleA
=
false
>
using
WarpGemmMfmaDispatcher
=
typename
impl
::
WarpGemmMfmaDispatcher
<
AType
,
BType
,
CType
,
MPerWave
,
NPerWave
,
KPerWave
,
TransposeC
,
SwizzleA
>::
Type
;
}
// namespace ck_tile
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
View file @
0eb75e21
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
public:
Argument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
long_
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
)
:
input_
{
input
},
output_
{
output
},
conv_strides_
{
conv_filter_strides
},
...
...
@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator
const
Tensor
<
InDataType
>&
input_
;
Tensor
<
OutDataType
>&
output_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
long_
index_t
>
conv_strides_
;
std
::
vector
<
long_
index_t
>
conv_dilations_
;
std
::
vector
<
long_
index_t
>
in_left_pads_
;
std
::
vector
<
long_
index_t
>
in_right_pads_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
long_
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
long_
index_t
>
output_spatial_lengths_
;
private:
void
initOutputSpatialLengths
()
{
constexpr
auto
input_offset_to_spatial
=
3
;
for
(
ck
::
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
ck
::
long_
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck
::
index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_dilations_
[
i
]
+
1
;
const
ck
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_dilations_
[
i
]
+
1
;
output_spatial_lengths_
.
push_back
(
(
output_
.
GetLengths
()[
i
+
input_offset_to_spatial
]
+
in_left_pads_
[
i
]
+
...
...
@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
}
const
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
const
long_
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
long_
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
long_
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
if
constexpr
(
NDimSpatial
==
1
)
{
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
long_
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
index_t
row
=
n
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
0
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
0
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
output_
.
GetLengths
()[
3
])
...
...
@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
const
long_
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
for
(
long_
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
for
(
long_
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
0
];
++
y
)
for
(
long_
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
0
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
1
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
1
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
hi
>=
0
&&
...
...
@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
index_t
Do
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
const
long_
index_t
Do
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
g
,
auto
n
)
{
for
(
index_t
d_o
=
0
;
d_o
<
Do
;
++
d_o
)
for
(
long_
index_t
d_o
=
0
;
d_o
<
Do
;
++
d_o
)
{
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
for
(
long_
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
for
(
long_
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
z
=
0
;
z
<
arg
.
filter_spatial_lengths_
[
0
];
++
z
)
for
(
long_
index_t
z
=
0
;
z
<
arg
.
filter_spatial_lengths_
[
0
];
++
z
)
{
auto
di
=
static_cast
<
ck
::
long_index_t
>
(
d_o
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
1
];
++
y
)
for
(
long_
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
1
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
...
...
@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
2
];
++
x
)
for
(
long_index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
2
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
...
...
@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
2
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
...
...
@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
const
ck
::
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
ck
::
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
ck
::
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
const
ck
::
long_
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
ck
::
long_
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
ck
::
long_
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
const
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
const
long_
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
long_
index_t
>
(
arg
.
output_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
index_t
>
(
const
long_
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
long_
index_t
>
(
arg
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
!
(
arg
.
input_
.
GetLengths
()[
0
]
==
static_cast
<
std
::
size_t
>
(
G
)
&&
...
...
@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
static
auto
MakeArgument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
long_
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
)
{
return
Argument
{
input
,
output
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
0eb75e21
...
...
@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
const
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors_
;
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
long_
index_t
>
conv_strides_
;
std
::
vector
<
long_
index_t
>
conv_dilations_
;
std
::
vector
<
long_
index_t
>
in_left_pads_
;
std
::
vector
<
long_
index_t
>
in_right_pads_
;
InElementwiseOperation
in_element_op_
;
WeiElementwiseOperation
wei_element_op_
;
...
...
@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
const
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
View file @
0eb75e21
...
...
@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const
Tensor
<
InDataType
>&
in_n_c_hi_wi
,
Tensor
<
WeiDataType
>&
wei_k_c_y_x
,
const
Tensor
<
OutDataType
>&
out_n_k_ho_wo
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const
std
::
array
<
Tensor
<
InDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors_
;
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
long_
index_t
>
conv_strides_
;
std
::
vector
<
long_
index_t
>
conv_dilations_
;
std
::
vector
<
long_
index_t
>
in_left_pads_
;
std
::
vector
<
long_
index_t
>
in_right_pads_
;
InElementwiseOperation
in_element_op_
;
WeiElementwiseOperation
wei_element_op_
;
...
...
@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const
Tensor
<
InDataType
>&
in_n_c_hi_wi
,
Tensor
<
WeiDataType
>&
wei_k_c_y_x
,
const
Tensor
<
OutDataType
>&
out_n_k_ho_wo
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
0eb75e21
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const
std
::
array
<
Tensor
<
WeiDataType
>
,
NumBElementwiseTensor
>&
elementwise_b_tensors_
;
const
std
::
array
<
Tensor
<
OutDataType
>
,
NumDElementwiseTensor
>&
elementwise_d_tensors_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
ck
::
long_
index_t
>
conv_strides_
;
std
::
vector
<
ck
::
long_
index_t
>
conv_dilations_
;
std
::
vector
<
ck
::
long_
index_t
>
in_left_pads_
;
std
::
vector
<
ck
::
long_
index_t
>
in_right_pads_
;
InElementwiseOperation
in_element_op_
;
WeiElementwiseOperation
wei_element_op_
;
...
...
@@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const
Tensor
<
InDataType
>&
input
,
const
Tensor
<
WeiDataType
>&
weight
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp
View file @
0eb75e21
...
...
@@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
public:
Argument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
long_
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
)
:
input_
{
input
},
output_
{
output
},
conv_strides_
{
conv_filter_strides
},
...
...
@@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator
const
Tensor
<
InDataType
>&
input_
;
Tensor
<
OutDataType
>&
output_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
long_
index_t
>
conv_strides_
;
std
::
vector
<
long_
index_t
>
conv_dilations_
;
std
::
vector
<
long_
index_t
>
in_left_pads_
;
std
::
vector
<
long_
index_t
>
in_right_pads_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
long_
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
long_
index_t
>
output_spatial_lengths_
;
private:
void
initOutputSpatialLengths
()
...
...
@@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck
::
index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_dilations_
[
i
]
+
1
;
const
ck
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_dilations_
[
i
]
+
1
;
output_spatial_lengths_
.
push_back
(
(
input_
.
GetLengths
()[
i
+
input_offset_to_spatial
]
+
in_left_pads_
[
i
]
+
...
...
@@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
}
const
index_t
G
=
arg
.
input_
.
GetLengths
()[
0
];
const
index_t
N
=
arg
.
input_
.
GetLengths
()[
1
];
const
index_t
C
=
arg
.
input_
.
GetLengths
()[
2
];
const
long_
index_t
G
=
arg
.
input_
.
GetLengths
()[
0
];
const
long_
index_t
N
=
arg
.
input_
.
GetLengths
()[
1
];
const
long_
index_t
C
=
arg
.
input_
.
GetLengths
()[
2
];
if
constexpr
(
NDimSpatial
==
1
)
{
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
wo
)
{
index_t
row
=
n
*
Wo
+
wo
;
index_t
column
=
0
;
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
wo
)
{
long_
index_t
row
=
n
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
0
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
0
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
3
])
...
...
@@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
const
long_
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
ho
,
auto
wo
)
{
index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
0
];
++
y
)
for
(
long_
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
0
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
1
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
1
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
hi
>=
0
&&
...
...
@@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
index_t
Do
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
const
long_
index_t
Do
=
arg
.
output_spatial_lengths_
[
0
];
const
long_
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
long_
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
long_
index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_
index_t
column
=
0
;
for
(
index_t
z
=
0
;
z
<
arg
.
filter_spatial_lengths_
[
0
];
++
z
)
for
(
long_
index_t
z
=
0
;
z
<
arg
.
filter_spatial_lengths_
[
0
];
++
z
)
{
auto
di
=
static_cast
<
ck
::
long_index_t
>
(
d_o
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
1
];
++
y
)
for
(
long_
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
1
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
2
];
++
x
)
for
(
long_
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
2
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
2
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
2
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
for
(
long_
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
...
...
@@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
const
ck
::
index_t
G
=
arg
.
input_
.
GetLengths
()[
0
];
const
ck
::
index_t
N
=
arg
.
input_
.
GetLengths
()[
1
];
const
ck
::
index_t
C
=
arg
.
input_
.
GetLengths
()[
2
];
const
ck
::
long_
index_t
G
=
arg
.
input_
.
GetLengths
()[
0
];
const
ck
::
long_
index_t
N
=
arg
.
input_
.
GetLengths
()[
1
];
const
ck
::
long_
index_t
C
=
arg
.
input_
.
GetLengths
()[
2
];
const
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
const
long_
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
long_
index_t
>
(
arg
.
output_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
index_t
>
(
const
long_
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
long_
index_t
>
(
arg
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
!
(
arg
.
output_
.
GetLengths
()[
0
]
==
static_cast
<
std
::
size_t
>
(
G
)
&&
...
...
@@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
static
auto
MakeArgument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
long_
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
)
{
return
Argument
{
input
,
output
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
View file @
0eb75e21
...
...
@@ -18,134 +18,82 @@ namespace device {
namespace
instance
{
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
#endif
template
<
typename
ADataType
,
...
...
@@ -154,7 +102,7 @@ template <typename ADataType,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
SplitK
<
ALayout
,
BLayout
,
Tuple
<
Row
,
Col
>
,
...
...
@@ -167,17 +115,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>>
{
using
DeviceOp
=
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
Tuple
<
Row
,
Col
>
,
CLayout
,
ADataType
,
BDataType
,
Tuple
<
F32
,
F32
>
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>
;
using
DeviceOp
=
DeviceGemmMultipleDSplitK
<
ALayout
,
BLayout
,
Tuple
<
Row
,
Col
>
,
CLayout
,
ADataType
,
BDataType
,
Tuple
<
F32
,
F32
>
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>
;
static
auto
GetInstances
()
{
...
...
@@ -194,24 +143,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
#endif
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
View file @
0eb75e21
...
...
@@ -77,16 +77,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -97,11 +87,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -111,13 +96,8 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#if(defined(CK_ENABLE_FP16)
||
defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_FP16)
&&
defined(CK_ENABLE_FP8))
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -177,16 +157,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -196,12 +166,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -212,10 +176,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -275,16 +235,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -295,11 +245,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -309,11 +254,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
...
...
@@ -376,93 +316,98 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instanc
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_
comp_mnpadding
_instances
(
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_
mem_v1_default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_
comp_mn
kpadding_instances
(
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_
mem_v1_
kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v
1
_default_instances
(
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v
2
_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v
1
_kpadding_instances
(
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v
2
_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_
bf16_bf16
_bf16_mk_
n
k_mn_
mem_v1_mn
kpadding_instances
(
void
add_device_gemm_xdl_universal_
f8_f8
_bf16_mk_k
n
_mn_
comp_
kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_
bf16_bf16
_bf16_mk_
n
k_mn_
mem_v2_default
_instances
(
void
add_device_gemm_xdl_universal_
f8_f8
_bf16_mk_k
n
_mn_
comp_nkpadding
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_
bf16_bf16
_bf16_mk_
n
k_mn_mem_v
2_kpadding
_instances
(
void
add_device_gemm_xdl_universal_
f8_f8
_bf16_mk_k
n
_mn_mem_v
1_default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_
bf16_bf16
_bf16_mk_
n
k_mn_mem_v
2_mn
kpadding_instances
(
void
add_device_gemm_xdl_universal_
f8_f8
_bf16_mk_k
n
_mn_mem_v
1_
kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_
n
k_mn_
comp_kpadding
_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_k
n
_mn_
mem_v2_default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_
n
k_mn_
comp_mn
padding_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_k
n
_mn_
mem_v2_k
padding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_
n
k_mn_
comp_m
nkpadding_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_k
n
_mn_
mem_v2_
nkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_
mem_v1
_default_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_
comp
_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_
mem_v1
_kpadding_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_
comp
_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_
mnkpadding
_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_
default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v
2_default
_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v
1_kpadding
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_
kpadding
_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_
default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_
mn
kpadding_instances
(
void
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
...
...
@@ -532,28 +477,20 @@ struct DeviceOperationInstanceFactory<
{
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
#endif
#if(defined(CK_ENABLE_FP16)
||
defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_FP16)
&&
defined(CK_ENABLE_FP8))
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
...
...
@@ -562,21 +499,14 @@ struct DeviceOperationInstanceFactory<
{
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
...
...
@@ -608,21 +538,14 @@ struct DeviceOperationInstanceFactory<
{
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
...
...
@@ -684,51 +607,55 @@ struct DeviceOperationInstanceFactory<
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
#endif
#if(defined(CK_ENABLE_BF16)
||
defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_BF16)
&&
defined(CK_ENABLE_FP8))
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
#endif
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
me
rge
d_groups
_instance.hpp
→
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
la
rge
_tensor
_instance.hpp
View file @
0eb75e21
...
...
@@ -3,7 +3,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_
ab
d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_
large_tensor_
cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
...
@@ -29,8 +29,6 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
static
constexpr
auto
ConvFwd3x3
=
ConvolutionForwardSpecialization
::
Filter3x3
;
static
constexpr
auto
GemmMNKPadding
=
GemmSpecialization
::
MNKPadding
;
template
<
index_t
NDimSpatial
,
...
...
@@ -39,16 +37,16 @@ template <index_t NDimSpatial,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_
me
rge
d_groups
_bf16_instances
=
std
::
tuple
<
using
device_grouped_conv_fwd_xdl_
la
rge
_tensor
_bf16_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
ACompute| BCompute| BlockGemm| NumGroups|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
Type| Type| Pipeline| ToMerge|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
| | Scheduler| |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
| | | |
//
Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultiple
AB
D_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
1
6
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
BF16
,
BF16
,
LoopScheduler
::
Default
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
BF16
,
BF16
,
LoopScheduler
::
Default
,
16
>
,
DeviceGroupedConvFwdMultiple
AB
D_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
1
6
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
1
6
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
BF16
,
BF16
,
LoopScheduler
::
Default
,
32
>
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//
generic instance
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
6
4
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
1
28
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
6
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
@@ -58,16 +56,16 @@ template <index_t NDimSpatial,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_
me
rge
d_groups
_f16_instances
=
std
::
tuple
<
using
device_grouped_conv_fwd_xdl_
la
rge
_tensor
_f16_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//
Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultiple
AB
D_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F16
,
F16
,
LoopScheduler
::
Default
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F16
,
F16
,
LoopScheduler
::
Default
,
16
>
,
DeviceGroupedConvFwdMultiple
AB
D_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F16
,
F16
,
LoopScheduler
::
Default
,
32
>
//
generic instance
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
@@ -77,19 +75,18 @@ template <index_t NDimSpatial,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_
me
rge
d_groups
_f32_instances
=
std
::
tuple
<
using
device_grouped_conv_fwd_xdl_
la
rge
_tensor
_f32_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//
Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultiple
AB
D_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
DsLayout
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F32
,
F32
,
LoopScheduler
::
Default
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
DsLayout
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F32
,
F32
,
LoopScheduler
::
Default
,
16
>
,
DeviceGroupedConvFwdMultiple
AB
D_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
DsLayout
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
1
6
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F32
,
F32
,
LoopScheduler
::
Default
,
32
>
//
generic instance
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
DsLayout
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
DsLayout
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
1
28
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
// clang-format on
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
View file @
0eb75e21
...
...
@@ -17,7 +17,7 @@
#endif
#ifdef CK_USE_XDL
#include "grouped_convolution_forward_xdl.inc"
#include "grouped_convolution_forward_xdl_
me
rge
d_groups
.inc"
#include "grouped_convolution_forward_xdl_
la
rge
_tensor
.inc"
#include "grouped_convolution_forward_comp_xdl.inc"
#include "grouped_convolution_forward_mem_inter_xdl.inc"
#include "grouped_convolution_forward_mem_intra_xdl.inc"
...
...
@@ -200,7 +200,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
BComputeType
,
float
>
)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_f32_instances
(
add_device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_f32_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
(
...
...
@@ -215,7 +215,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
BComputeType
,
half_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_f16_instances
(
add_device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_f16_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances
(
...
...
@@ -232,7 +232,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
BComputeType
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_bf16_instances
(
add_device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_bf16_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances
(
...
...
@@ -291,7 +291,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
BComputeType
,
float
>
)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
op_ptrs
);
add_device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
add_device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
op_ptrs
);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances
(
op_ptrs
);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances
(
...
...
@@ -347,7 +347,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
BComputeType
,
half_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
op_ptrs
);
add_device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
add_device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
op_ptrs
);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
(
op_ptrs
);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances
(
...
...
@@ -364,7 +364,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
BComputeType
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
op_ptrs
);
add_device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
add_device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
op_ptrs
);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
(
op_ptrs
);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_
me
rge
d_groups
.inc
→
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_
la
rge
_tensor
.inc
View file @
0eb75e21
...
...
@@ -10,7 +10,7 @@ namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void
add_device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_bf16_instances
(
void
add_device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
...
...
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_inst
#endif
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_f16_instances
(
void
add_device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
...
...
@@ -42,7 +42,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_f32_instances
(
void
add_device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
...
...
@@ -59,7 +59,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
void
add_device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
...
...
@@ -75,7 +75,7 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_i
#endif
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
void
add_device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
...
...
@@ -91,7 +91,7 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_in
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
void
add_device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
...
...
library/include/ck/library/utility/check_err.hpp
View file @
0eb75e21
...
...
@@ -146,7 +146,7 @@ check_err(const Range& out,
bool
res
{
true
};
int
err_count
=
0
;
double
err
=
0
;
double
max_err
=
std
::
n
umeric
_l
imits
<
ranges
::
range_value_t
<
Range
>>::
m
in
();
double
max_err
=
N
umeric
L
imits
<
ranges
::
range_value_t
<
Range
>>::
M
in
();
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
...
...
@@ -178,7 +178,9 @@ check_err(const Range& out,
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_integral_v
<
ranges
::
range_value_t
<
Range
>>
&&
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bhalf_t
>
)
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bhalf_t
>
&&
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f8_t
>
&&
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bf8_t
>
)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
||
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
int4_t
>
#endif
...
...
library/include/ck/library/utility/convolution_parameter.hpp
View file @
0eb75e21
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -31,23 +31,35 @@ struct ConvParam
const
std
::
vector
<
ck
::
index_t
>&
left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
right_pads
);
ck
::
index_t
num_dim_spatial_
;
ck
::
index_t
G_
;
ck
::
index_t
N_
;
ck
::
index_t
K_
;
ck
::
index_t
C_
;
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
conv_filter_strides_
;
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations_
;
std
::
vector
<
ck
::
index_t
>
input_left_pads_
;
std
::
vector
<
ck
::
index_t
>
input_right_pads_
;
std
::
vector
<
ck
::
index_t
>
GetOutputSpatialLengths
()
const
;
ConvParam
(
ck
::
long_index_t
n_dim
,
ck
::
long_index_t
group_count
,
ck
::
long_index_t
n_batch
,
ck
::
long_index_t
n_out_channels
,
ck
::
long_index_t
n_in_channels
,
const
std
::
vector
<
ck
::
long_index_t
>&
filters_len
,
const
std
::
vector
<
ck
::
long_index_t
>&
input_len
,
const
std
::
vector
<
ck
::
long_index_t
>&
strides
,
const
std
::
vector
<
ck
::
long_index_t
>&
dilations
,
const
std
::
vector
<
ck
::
long_index_t
>&
left_pads
,
const
std
::
vector
<
ck
::
long_index_t
>&
right_pads
);
ck
::
long_index_t
num_dim_spatial_
;
ck
::
long_index_t
G_
;
ck
::
long_index_t
N_
;
ck
::
long_index_t
K_
;
ck
::
long_index_t
C_
;
std
::
vector
<
ck
::
long_index_t
>
filter_spatial_lengths_
;
std
::
vector
<
ck
::
long_index_t
>
input_spatial_lengths_
;
std
::
vector
<
ck
::
long_index_t
>
output_spatial_lengths_
;
std
::
vector
<
ck
::
long_index_t
>
conv_filter_strides_
;
std
::
vector
<
ck
::
long_index_t
>
conv_filter_dilations_
;
std
::
vector
<
ck
::
long_index_t
>
input_left_pads_
;
std
::
vector
<
ck
::
long_index_t
>
input_right_pads_
;
std
::
vector
<
ck
::
long_index_t
>
GetOutputSpatialLengths
()
const
;
std
::
size_t
GetFlops
()
const
;
...
...
library/include/ck/library/utility/host_tensor.hpp
View file @
0eb75e21
...
...
@@ -96,9 +96,16 @@ struct HostTensorDescriptor
this
->
CalculateStrides
();
}
HostTensorDescriptor
(
const
std
::
initializer_list
<
ck
::
long_index_t
>&
lens
)
:
mLens
(
lens
.
begin
(),
lens
.
end
())
{
this
->
CalculateStrides
();
}
template
<
typename
Lengths
,
typename
=
std
::
enable_if_t
<
std
::
is_convertible_v
<
ck
::
ranges
::
range_value_t
<
Lengths
>,
std
::
size_t
>>>
std
::
is_convertible_v
<
ck
::
ranges
::
range_value_t
<
Lengths
>,
std
::
size_t
>
||
std
::
is_convertible_v
<
ck
::
ranges
::
range_value_t
<
Lengths
>
,
ck
::
long_index_t
>>>
HostTensorDescriptor
(
const
Lengths
&
lens
)
:
mLens
(
lens
.
begin
(),
lens
.
end
())
{
this
->
CalculateStrides
();
...
...
@@ -114,11 +121,19 @@ struct HostTensorDescriptor
{
}
HostTensorDescriptor
(
const
std
::
initializer_list
<
ck
::
long_index_t
>&
lens
,
const
std
::
initializer_list
<
ck
::
long_index_t
>&
strides
)
:
mLens
(
lens
.
begin
(),
lens
.
end
()),
mStrides
(
strides
.
begin
(),
strides
.
end
())
{
}
template
<
typename
Lengths
,
typename
Strides
,
typename
=
std
::
enable_if_t
<
std
::
is_convertible_v
<
ck
::
ranges
::
range_value_t
<
Lengths
>,
std
::
size_t
>
&&
std
::
is_convertible_v
<
ck
::
ranges
::
range_value_t
<
Strides
>
,
std
::
size_t
>>>
(
std
::
is_convertible_v
<
ck
::
ranges
::
range_value_t
<
Lengths
>,
std
::
size_t
>
&&
std
::
is_convertible_v
<
ck
::
ranges
::
range_value_t
<
Strides
>
,
std
::
size_t
>
)
||
(
std
::
is_convertible_v
<
ck
::
ranges
::
range_value_t
<
Lengths
>
,
ck
::
long_index_t
>
&&
std
::
is_convertible_v
<
ck
::
ranges
::
range_value_t
<
Strides
>
,
ck
::
long_index_t
>
)
>>
HostTensorDescriptor
(
const
Lengths
&
lens
,
const
Strides
&
strides
)
:
mLens
(
lens
.
begin
(),
lens
.
end
()),
mStrides
(
strides
.
begin
(),
strides
.
end
())
{
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
0eb75e21
...
...
@@ -64,6 +64,13 @@ function(add_instance_library INSTANCE_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
# Do not build mha instances if gfx94 targets are not on the target list
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"mha"
)
message
(
"removing mha instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
#only continue if there are some source files left on the list
if
(
ARGN
)
set
(
INST_OBJ
)
...
...
@@ -74,9 +81,11 @@ function(add_instance_library INSTANCE_NAME)
set
(
INST_TARGETS
${
GPU_TARGETS
}
)
endif
()
if
(
source MATCHES
"_xdl"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
gfx1200 gfx1201
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"mha"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
set
(
offload_targets
)
foreach
(
target IN LISTS INST_TARGETS
)
...
...
@@ -86,7 +95,29 @@ function(add_instance_library INSTANCE_NAME)
list
(
APPEND INST_OBJ
${
source
}
)
endforeach
()
add_library
(
${
INSTANCE_NAME
}
OBJECT
${
INST_OBJ
}
)
# Allow comparing floating points directly in order to check sentinel values
if
(
${
INSTANCE_NAME
}
STREQUAL
"device_mha_instance"
)
if
(
NOT DEFINED FMHA_FWD_FAST_EXP2
)
set
(
FMHA_FWD_FAST_EXP2 true
)
endif
()
if
(
FMHA_FWD_FAST_EXP2
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero
)
else
()
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0
)
endif
()
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal
)
target_compile_options
(
device_mha_instance PRIVATE
${
EXAMPLE_FMHA_FWD_COMPILE_OPTIONS
}
)
endif
()
target_compile_features
(
${
INSTANCE_NAME
}
PUBLIC
)
# flags to compress the library
if
(
NOT WIN32 AND
${
hip_VERSION_FLAT
}
GREATER 600241132
)
message
(
"Adding --offload-compress flag for
${
INSTANCE_NAME
}
"
)
target_compile_options
(
${
INSTANCE_NAME
}
PRIVATE --offload-compress
)
endif
()
set_target_properties
(
${
INSTANCE_NAME
}
PROPERTIES POSITION_INDEPENDENT_CODE ON
)
clang_tidy_check
(
${
INSTANCE_NAME
}
)
set
(
result 0
)
...
...
@@ -286,20 +317,22 @@ if(CK_DEVICE_CONV_INSTANCES)
)
endif
()
if
(
CK_DEVICE_MHA_INSTANCES
)
add_library
(
device_mha_operations STATIC
${
CK_DEVICE_MHA_INSTANCES
}
)
add_library
(
composablekernels::device_mha_operations ALIAS device_mha_operations
)
target_compile_features
(
device_mha_operations PUBLIC
)
set_target_properties
(
device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
target_include_directories
(
device_mha_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu/mha>
)
rocm_install
(
TARGETS device_mha_operations
EXPORT device_mha_operationsTargets
)
rocm_install
(
EXPORT device_mha_operationsTargets
FILE composable_kerneldevice_mha_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
)
set
(
gpu_list
${
INST_TARGETS
}
)
list
(
FILTER gpu_list INCLUDE REGEX
"^gfx94"
)
if
(
gpu_list
)
add_library
(
device_mha_operations STATIC
${
CK_DEVICE_MHA_INSTANCES
}
)
add_library
(
composablekernels::device_mha_operations ALIAS device_mha_operations
)
target_compile_features
(
device_mha_operations PUBLIC
)
set_target_properties
(
device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
rocm_install
(
TARGETS device_mha_operations
EXPORT device_mha_operationsTargets
)
rocm_install
(
EXPORT device_mha_operationsTargets
FILE composable_kerneldevice_mha_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
)
endif
()
endif
()
if
(
CK_DEVICE_CONTRACTION_INSTANCES
)
add_library
(
device_contraction_operations STATIC
${
CK_DEVICE_CONTRACTION_INSTANCES
}
)
...
...
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
0eb75e21
...
...
@@ -111,6 +111,7 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES
device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt
View file @
0eb75e21
...
...
@@ -11,4 +11,9 @@ list(APPEND GEMM_AB_SCALE_INSTANCES
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
)
set_source_files_properties
(
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
add_instance_library
(
device_gemm_ab_scale_instance
${
GEMM_AB_SCALE_INSTANCES
}
)
Prev
1
2
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment