gaoqiong / composable_kernel_ROCM · Commits

Commit 129e58ae, authored Jun 05, 2024 by Adam Osewski

    Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

Parents: 9bebfd42, cb0645be

Changes: 188 files. Showing 20 changed files with 625 additions and 173 deletions (+625 -173).
Changed files:

include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp  +56 -0
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp  +0 -26
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp  +1 -1
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp  +1 -1
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp  +1 -1
include/ck_tile/ops/gemm/block/block_gemm_problem.hpp  +2 -2
include/ck_tile/ops/gemm/warp/warp_gemm.hpp  +8 -2
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp  +88 -0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp  +10 -14
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp  +16 -6
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp  +78 -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp  +6 -2
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc  +26 -2
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp  +108 -0
library/src/tensor_operation_instance/gpu/CMakeLists.txt  +42 -8
library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt  +6 -1
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp  +58 -0
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp  +58 -0
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp  +1 -107
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp  +59 -0
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"

namespace ck_tile {

// Default policy for BlockGemmASmemBRegCRegV1.
// A default policy class should not be templated; put the template on member functions instead.
struct BlockGemmASmemBRegCRegV1DefaultPolicy
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
                     std::is_same_v<typename Problem::BDataType, half_t> &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
#if 0
            constexpr index_t kBlockSize = Problem::kBlockSize;
            constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
            constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
            constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

            static_assert(kBlockSize % get_warp_size() == 0, "wrong!");

            constexpr index_t NumWarp = kBlockSize / get_warp_size();

            // FIXME: both branches currently return the same configuration.
            if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
                         kNPerBlock % 128 == 0 && kKPerBlock % 16 == 0)
            {
                return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
            }
            else
            {
                return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
            }
#else
            return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
#endif
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                          std::is_same_v<typename Problem::BDataType, bf16_t> &&
                          std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
        }
    }
};

} // namespace ck_tile
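A minimal usage sketch of the new policy (hypothetical, not part of the commit; ProblemF16 is a stand-in for a real BlockGemmProblem instantiation, and only its three data-type members are needed by the branch exercised here):

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"

// Hypothetical problem type; a stand-in for ck_tile::BlockGemmProblem<...>.
struct ProblemF16
{
    using ADataType = ck_tile::half_t;
    using BDataType = ck_tile::half_t;
    using CDataType = float;
};

// Returns (warp-gemm functor, MWarp, NWarp); the fp16/fp32 branch above yields
// make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1).
constexpr auto warp_gemm_cfg =
    ck_tile::BlockGemmASmemBRegCRegV1DefaultPolicy::GetWarpGemmMWarpNWarp<ProblemF16>();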
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp (deleted, mode 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

// Problem Description for BlockGemmASmemBSmemCRegV1
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          index_t kBlockSize_,
          typename BlockGemmShape_>
struct BlockGemmASmemBSmemCRegProblem
{
    using ADataType      = remove_cvref_t<ADataType_>;
    using BDataType      = remove_cvref_t<BDataType_>;
    using CDataType      = remove_cvref_t<CDataType_>;
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    static constexpr index_t kBlockSize = kBlockSize_;
};

} // namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
 ...

include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
 ...

include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
 ...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp → include/ck_tile/ops/gemm/block/block_gemm_problem.hpp

@@ -7,13 +7,13 @@
 namespace ck_tile {

-// Problem Description for BlockGemmARegBSmemCReg
+// Problem Description for BlockGemm
 template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           index_t kBlockSize_,
           typename BlockGemmShape_>
-struct BlockGemmARegBSmemCRegProblem
+struct BlockGemmProblem
 {
     using ADataType = remove_cvref_t<ADataType_>;
     using BDataType = remove_cvref_t<BDataType_>;
 ...
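After the rename, the problem descriptor no longer encodes a particular A/B/C placement, so one type can serve every block-GEMM variant. A sketch of an instantiation (the shape type here is hypothetical, standing in for the BlockGemmShape types used elsewhere in ck_tile):

// Hypothetical 128x128x32 block tile shape.
struct MyBlockGemmShape
{
    static constexpr ck_tile::index_t kM = 128;
    static constexpr ck_tile::index_t kN = 128;
    static constexpr ck_tile::index_t kK = 32;
};

using Problem = ck_tile::BlockGemmProblem<ck_tile::half_t, // ADataType
                                          ck_tile::half_t, // BDataType
                                          float,           // CDataType
                                          256,             // kBlockSize
                                          MyBlockGemmShape>;

static_assert(Problem::kBlockSize == 256, "kBlockSize is forwarded unchanged");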
include/ck_tile/ops/gemm/warp/warp_gemm.hpp

@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
 using WarpGemmMfmaF16F16F32M16N16K32 =
     WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;

+using WarpGemmMfmaF16F16F32M32N32K16SwizzleA =
+    WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
+
 using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
     WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;

@@ -38,7 +41,7 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
         WarpGemmAttributeMfmaImplF16F16F32M16N16K16,
         2>>;

-using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution =
+using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
     WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
         WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
         2>>;

@@ -56,6 +59,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
 using WarpGemmMfmaBf16Bf16F32M16N16K32 =
     WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;

+using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
+    WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
+
 using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
     WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;

@@ -72,7 +78,7 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
         WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16,
         2>>;

-using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
+using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
     WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
         WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
         2>>;
 ...
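The rename fixes a misleading alias: the SwizzleB type iterates a 32x32x8 MFMA twice along K, so its warp tile is 32x32x16, not 16x16x32. A quick shape check (a sketch, assuming WarpGemmImpl forwards kM/kN/kK from its attribute the way the other aliases in this header rely on):

using WG_f16 = ck_tile::WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution;
static_assert(WG_f16::kM == 32 && WG_f16::kN == 32 && WG_f16::kK == 16,
              "two K-iterations of a 32x32x8 MFMA give a 32x32x16 warp tile");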
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp

@@ -468,4 +468,92 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
     }
 };

The following struct is added at the end of the file:

template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
struct WarpGemmAtrributeMfmaIterateK_SwizzleA
{
    using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;

    using ADataType = typename Impl::ADataType;
    using BDataType = typename Impl::BDataType;
    using CDataType = typename Impl::CDataType;

    using AVecType =
        ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
    using BVecType =
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK * kKIter;

    static constexpr index_t SFactor = SFactor_; // group how many CM1 together

    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
                       Impl::kCMLane,
                       SFactor,
                       Impl::kCM1PerLane>,
              sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
        tuple<sequence<2, 1, 1, 1, 1>>,
        tuple<sequence<0, 0, 2, 1, 3>>,
        sequence<2>,
        sequence<1>>;

    using BWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
        tuple<sequence<2, 1>>,
        tuple<sequence<0, 0>>,
        sequence<2>,
        sequence<1>>;

    using CWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
              sequence<Impl::kCNLane>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 0>>,
        sequence<1, 1>,
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;

        static_for<0, kKIter, 1>{}([&](auto iKIter) {
            Impl{}(c_vec,
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter]);
        });
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
        constexpr auto I0 = number<0>{};

        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;

        auto c_vec = Impl{}(
            reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
            reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);

        static_for<1, kKIter, 1>{}([&](auto iKIter) {
            Impl{}(c_vec,
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter]);
        });

        return c_vec;
    }
};

} // namespace ck_tile
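The new attribute reinterprets one wide A/B register vector as kKIter narrow vectors and feeds the wrapped MFMA implementation once per K-slice, accumulating into the same C vector. The same reinterpret-and-loop idea, stripped of the GPU specifics, looks like this (a self-contained host-side sketch, not code from the commit):

#include <array>
#include <cstdio>

// Stand-ins: a "wide" K=4 input viewed as kKIter=2 chunks of a K=2 base op.
using Chunk = std::array<float, 2>;
using Wide  = std::array<Chunk, 2>; // same bytes as float[4]

// Base "Impl": accumulate the dot product of one chunk pair into c.
void impl(float& c, const Chunk& a, const Chunk& b)
{
    c += a[0] * b[0] + a[1] * b[1];
}

int main()
{
    Wide a{{{1.f, 2.f}, {3.f, 4.f}}};
    Wide b{{{5.f, 6.f}, {7.f, 8.f}}};

    float c = 0.f;
    // The IterateK idea: apply the narrow op once per K-slice, accumulating.
    for(int iKIter = 0; iKIter < 2; ++iKIter)
        impl(c, a[iKIter], b[iKIter]);

    std::printf("%f\n", c); // 1*5 + 2*6 + 3*7 + 4*8 = 70
    return 0;
}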
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp

@@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)
+#if defined(__gfx9__)
         c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
 #else
         ck_tile::ignore = c_vec;
 ...

@@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)
+#if defined(__gfx9__)
         return bit_cast<CVecType>(
             __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
 #else
 ...

@@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)
+#if defined(__gfx9__)
         c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
 #else
         ck_tile::ignore = c_vec;
 ...

@@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)
+#if defined(__gfx9__)
         return bit_cast<CVecType>(
             __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
 #else
 ...

@@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
 #elif defined(__gfx908__)
         static_for<0, 2, 1>{}([&](auto k) {
 ...

@@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         return bit_cast<CVecType>(
             __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
 #elif defined(__gfx908__)
 ...

@@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
 #elif defined(__gfx908__)
         static_for<0, 2, 1>{}([&](auto k) {
 ...

@@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         return bit_cast<CVecType>(
             __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
 #elif defined(__gfx908__)
 ...

@@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
             c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                 bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
 ...

@@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
             return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                 bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
 ...
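Every hunk in this file makes the same substitution: the hand-maintained per-target lists collapse into the family macros __gfx9__ and __gfx94__ used elsewhere in the repository (the assumption here is that these umbrella macros are defined, for example in ck's configuration headers, whenever any member target is being compiled). The pattern in isolation:

// Before: every new gfx9 variant had to be appended by hand.
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || \
    defined(__gfx941__) || defined(__gfx942__)
// ... MFMA builtin path ...
#endif

// After: one family check covers current and future gfx9 members.
#if defined(__gfx9__)
// ... MFMA builtin path ...
#endif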
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp

@@ -35,14 +35,24 @@ template <ck::index_t NDimSpatial,
           typename ALayout,
           typename BLayout,
           typename ELayout,
-          ConvolutionBackwardWeightSpecialization ConvSpec>
+          ConvolutionBackwardWeightSpecialization ConvSpec,
+          BlockGemmPipelineScheduler Scheduler,
+          BlockGemmPipelineVersion PipelineVersion>
 using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple<
     // clang-format off
     [column-header comments omitted; the new version appends three columns: BlockGemm Pipeline Scheduler, BlockGemm Pipeline Version, NumBatch ToMerge]
-    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 8, 1, 8>, 1>
+    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
+    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
+    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>,
+    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>,
+    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
+    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
+    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>,
+    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>
     // clang-format on
     >;
 ...
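With Scheduler and PipelineVersion lifted into template parameters, the same instance table can be stamped out once per pipeline flavor; the new *_pipev2_* and *_pipev5_* adders further down are built this way. A hypothetical instantiation (the enumerator names follow ck's existing BlockGemmPipelineScheduler / BlockGemmPipelineVersion enums, which the pipev2/pipev5 naming in this commit reflects; the layout and specialization choices are illustrative):

using F16TwoStageInstancesPipeV2 =
    device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
        2,                    // NDimSpatial
        NHWGC, GKYXC, NHWGK,  // layouts (illustrative)
        ConvolutionBackwardWeightSpecialization::Default,
        BlockGemmPipelineScheduler::Intrawave,
        BlockGemmPipelineVersion::v2>;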
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F32 = float;
using F8  = ck::f8_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;

static constexpr auto ConvFwdOddC =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;

static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec,
          typename OutElementOp>
using device_grouped_conv_fwd_xdl_outelementop_f8_instances = std::tuple<
    // clang-format off
    [column-header comment omitted: NumDim Spatial | layouts | data types | elementwise ops | specializations | block/XDL tile sizes | A/B block-transfer parameters | CShuffle | CBlockTransfer | compute types]
#ifdef CK_ENABLE_FP8
    // generic instance
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
    // instances for small conv.K and conv.C
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>
#endif
    // clang-format on
    >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp

@@ -352,7 +352,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         {
             add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
-            add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(op_ptrs);
         }
 #endif
 ...

@@ -421,7 +423,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         {
             add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
-            add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(op_ptrs);
         }
 #endif
 ...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc

@@ -114,7 +114,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
                                                            PassThrough,
                                                            PassThrough,
                                                            PassThrough>>>& instances);

-void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
+void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
+                                                           NHWGC,
+                                                           GKYXC,
+                                                           NHWGK,
+                                                           F16,
+                                                           F16,
+                                                           F16,
+                                                           PassThrough,
+                                                           PassThrough,
+                                                           PassThrough>>>& instances);
+
+void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                            NHWGC,
                                                            GKYXC,
 ...

@@ -205,7 +217,19 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
                                                            PassThrough,
                                                            PassThrough,
                                                            PassThrough>>>& instances);

-void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
+                                                           NDHWGC,
+                                                           GKZYXC,
+                                                           NDHWGK,
+                                                           F16,
+                                                           F16,
+                                                           F16,
+                                                           PassThrough,
+                                                           PassThrough,
+                                                           PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                            NDHWGC,
                                                            GKZYXC,
 ...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ConvScale   = ck::tensor_operation::element_wise::ConvScale;

#ifdef CK_ENABLE_FP8
void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                ck::Tuple<>,
                                                                NDHWGK,
                                                                F8,
                                                                F8,
                                                                ck::Tuple<>,
                                                                F8,
                                                                PassThrough,
                                                                PassThrough,
                                                                ConvScale,
                                                                F8,
                                                                F8>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DLayouts,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DDataTypes,
          typename OutDataType,
          typename AComputeType,
          typename BComputeType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
        NumDimSpatial,
        InLayout,
        WeiLayout,
        DLayouts,
        OutLayout,
        InDataType,
        WeiDataType,
        DDataTypes,
        OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::ConvScale,
        AComputeType,
        BComputeType>>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                                     InLayout,
                                                     WeiLayout,
                                                     DLayouts,
                                                     OutLayout,
                                                     InDataType,
                                                     WeiDataType,
                                                     DDataTypes,
                                                     OutDataType,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     ck::tensor_operation::element_wise::ConvScale,
                                                     AComputeType,
                                                     BComputeType>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
                     is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
        {
#ifdef CK_ENABLE_FP8
            if constexpr(is_same_v<InDataType, f8_t> && is_same_v<WeiDataType, f8_t> &&
                         is_same_v<OutDataType, f8_t> && is_same_v<AComputeType, f8_t> &&
                         is_same_v<BComputeType, f8_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs);
            }
#endif
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
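A sketch of how a client might pull the new ConvScale instances through this factory specialization (the F8/NDHWGC/GKZYXC/NDHWGK aliases come from ck headers as above; the list is non-empty only when CK_ENABLE_FP8 is defined at build time):

using ConvScaleDeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK,
    F8, F8, ck::Tuple<>, F8,
    PassThrough, PassThrough, ConvScale, F8, F8>;

auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<ConvScaleDeviceOp>::GetInstances();
// op_ptrs now holds the f8 NDHWGC/GKZYXC/NDHWGK instances registered above.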
library/src/tensor_operation_instance/gpu/CMakeLists.txt

@@ -36,6 +36,13 @@ function(add_instance_library INSTANCE_NAME)
             endif()
         endforeach()
     endif()
+    if(INSTANCES_ONLY)
+        set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
+    else()
+        set(INST_TARGETS ${GPU_TARGETS})
+    endif()
     # Do not build DL instances if DL_KERNELS macro is not set
     foreach(source IN LISTS ARGN)
         if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
 ...

@@ -45,21 +52,40 @@ function(add_instance_library INSTANCE_NAME)
     endforeach()
     # Do not build XDL instances if gfx9 targets are not on the target list
     foreach(source IN LISTS ARGN)
-        if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
+        if(NOT INST_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
             message("removing xdl instance ${source}")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     # Do not build WMMA instances if gfx11 targets are not on the target list
     foreach(source IN LISTS ARGN)
-        if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+        if(NOT INST_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
             message("removing wmma instance ${source}")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     #only continue if there are some source files left on the list
     if(ARGN)
-        add_library(${INSTANCE_NAME} OBJECT ${ARGN})
+        set(INST_OBJ)
+        foreach(source IN LISTS ARGN)
+            if(INSTANCES_ONLY)
+                set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
+            else()
+                set(INST_TARGETS ${GPU_TARGETS})
+            endif()
+            if(source MATCHES "_xdl")
+                list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
+            elseif(ARGN MATCHES "_wmma")
+                list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+            endif()
+            set(offload_targets)
+            foreach(target IN LISTS INST_TARGETS)
+                string(APPEND offload_targets "--offload-arch=${target} ")
+            endforeach()
+            set_source_files_properties(${source} PROPERTIES COMPILE_FLAGS ${offload_targets})
+            list(APPEND INST_OBJ ${source})
+        endforeach()
+        add_library(${INSTANCE_NAME} OBJECT ${INST_OBJ})
         target_compile_features(${INSTANCE_NAME} PUBLIC)
         set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
         clang_tidy_check(${INSTANCE_NAME})
 ...

@@ -131,6 +157,14 @@ FOREACH(subdir_path ${dir_list})
         if(NOT DEFINED DTYPES)
             set(add_inst 1)
         endif()
     endif()
+    if(INSTANCES_ONLY)
+        set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
+    else()
+        set(INST_TARGETS ${GPU_TARGETS})
+    endif()
     if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
         message("quantization instances will not be built!")
         set(add_inst 0)
 ...

@@ -139,23 +173,23 @@ FOREACH(subdir_path ${dir_list})
         message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
         set(add_inst 0)
     endif()
-    if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9"))
+    if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9"))
         message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
         set(add_inst 0)
     endif()
-    if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11"))
+    if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11"))
         message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
         set(add_inst 0)
     endif()
-    if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9"))
+    if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT INST_TARGETS MATCHES "gfx9"))
         message("Found only xdl and dl instances, but gfx9 is not on the targets list and DL_KERNELS is not set. Skipping.")
         set(add_inst 0)
     endif()
-    if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
+    if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9"))
         message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
         set(add_inst 0)
     endif()
-    if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
+    if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
         message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
         set(add_inst 0)
     endif()
 ...
library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt

@@ -2,9 +2,14 @@
 set(GEMM_MULTI_ABD_INSTANCES)

 list(APPEND GEMM_MULTI_ABD_INSTANCES
+    device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
     device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
     device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
     device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
 )
 ...
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
129e58ae
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
                                                      ck::Tuple<B0Layout, B1Layout>,
                                                      ck::Tuple<>,
                                                      ELayout,
                                                      AsDataType,
                                                      ck::Tuple<B0DataType, B1DataType>,
                                                      ck::Tuple<>,
                                                      EDataType,
                                                      AElementOp,
                                                      Multiply,
                                                      PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
            ck::Tuple<B0Layout, B1Layout>,
            ck::Tuple<>,
            ck::Tuple<B0DataType, B1DataType>,
            ck::Tuple<>,
            Multiply,
            PassThrough,
            GemmMNKPadding,
            Interwave>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
            ck::Tuple<B0Layout, B1Layout>,
            ck::Tuple<>,
            ck::Tuple<B0DataType, B1DataType>,
            ck::Tuple<>,
            Multiply,
            PassThrough,
            GemmMNKPadding,
            Interwave>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
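Each of these new translation units follows the same registry pattern: a strongly typed std::vector of std::unique_ptr to the abstract device-op interface is filled by add_device_operation_instances, so client code can enumerate all prebuilt tunings at runtime. The following self-contained sketch distills that pattern; DeviceOp, XdlGemmInstance, and add_op_instances are hypothetical stand-ins, not CK's actual types.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for CK's abstract device-op base
// (e.g. DeviceGemmMultipleABD in the file above).
struct DeviceOp
{
    virtual ~DeviceOp()                       = default;
    virtual std::string GetTypeString() const = 0;
};

// One concrete "instance", analogous to one template instantiation of the kernel.
struct XdlGemmInstance final : DeviceOp
{
    explicit XdlGemmInstance(std::string tag) : tag_(std::move(tag)) {}
    std::string GetTypeString() const override { return "xdl_gemm<" + tag_ + ">"; }
    std::string tag_;
};

// Analogue of add_device_operation_instances: append freshly constructed
// instances to the caller's registry vector.
void add_op_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{
    instances.push_back(std::make_unique<XdlGemmInstance>("comp,Interwave"));
    instances.push_back(std::make_unique<XdlGemmInstance>("mem,Interwave"));
}

int main()
{
    std::vector<std::unique_ptr<DeviceOp>> instances;
    add_op_instances(instances);
    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n';
}

In CK proper, a client typically also checks each instance with IsSupportedArgument for the concrete problem size and times the survivors to pick a winner; that selection loop lives outside this commit.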
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
                                                      ck::Tuple<B0Layout, B1Layout>,
                                                      ck::Tuple<D0Layout>,
                                                      ELayout,
                                                      AsDataType,
                                                      ck::Tuple<B0DataType, B1DataType>,
                                                      ck::Tuple<D0DataType>,
                                                      EDataType,
                                                      AElementOp,
                                                      Multiply,
                                                      Add>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
            ck::Tuple<B0Layout, B1Layout>,
            ck::Tuple<D0Layout>,
            ck::Tuple<B0DataType, B1DataType>,
            ck::Tuple<D0DataType>,
            Multiply,
            Add,
            GemmMNKPadding,
            Interwave>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
            ck::Tuple<B0Layout, B1Layout>,
            ck::Tuple<D0Layout>,
            ck::Tuple<B0DataType, B1DataType>,
            ck::Tuple<D0DataType>,
            Multiply,
            Add,
            GemmMNKPadding,
            Interwave>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
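The bias variant differs from the plain one only in its element-wise operators: the two B tensors are combined with Multiply (the usual pattern for an int8 weight times a bf16 dequantization scale) and the single D tensor is fused into the epilogue with Add. Below is a scalar sketch of that data flow; gemm_multi_abd_bias_element is a hypothetical simplification written here to illustrate how the Multiply/Add functors compose, not CK code (real kernels work on tiles, not scalars).

#include <cstdint>
#include <iostream>

// One output element E[m][n] of the fused kernel, modeled scalar-by-scalar.
float gemm_multi_abd_bias_element(const float* a_row,   // A[m][:], bf16 widened to float
                                  const int8_t* b0_col, // B0[:][n], int8 weights
                                  const float* b1_col,  // B1[:][n], bf16 dequant scales
                                  float d0,             // D0[m][n] (or broadcast bias)
                                  int K)
{
    float c = 0.f;
    for(int k = 0; k < K; ++k)
        c += a_row[k] * (static_cast<float>(b0_col[k]) * b1_col[k]); // BElementOp = Multiply
    return c + d0;                                                   // CDEElementOp = Add
}

int main()
{
    const float a[2]   = {1.f, 2.f};
    const int8_t b0[2] = {3, -4};
    const float b1[2]  = {0.5f, 0.25f};
    // 1*(3*0.5) + 2*(-4*0.25) + 1 = 1.5 - 2 + 1 = 0.5
    std::cout << gemm_multi_abd_bias_element(a, b0, b1, 1.f, 2) << '\n';
}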
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
...
@@ -52,112 +52,6 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
            Interwave>{});
}
[The lines removed by this hunk are the definitions of add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances, ..._mk_kn_mn_v1_instances, and ..._mk_kn_mn_gelu_v1_instances — the same functions this commit adds, unchanged, in the three new per-operation .cpp files shown on this page.]
} // namespace instance
} // namespace device
} // namespace tensor_operation
...
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
                                                      ck::Tuple<B0Layout, B1Layout>,
                                                      ck::Tuple<>,
                                                      ELayout,
                                                      AsDataType,
                                                      ck::Tuple<B0DataType, B1DataType>,
                                                      ck::Tuple<>,
                                                      EDataType,
                                                      AElementOp,
                                                      Multiply,
                                                      FastGelu>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
            ck::Tuple<B0Layout, B1Layout>,
            ck::Tuple<>,
            ck::Tuple<B0DataType, B1DataType>,
            ck::Tuple<>,
            Multiply,
            FastGelu,
            GemmMNKPadding,
            Interwave>{});
    add_device_operation_instances(
        instances,
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
            ck::Tuple<B0Layout, B1Layout>,
            ck::Tuple<>,
            ck::Tuple<B0DataType, B1DataType>,
            ck::Tuple<>,
            Multiply,
            FastGelu,
            GemmMNKPadding,
            Interwave>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
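The gelu variant swaps the epilogue functor for FastGelu. Fast-GELU epilogues of this kind are tanh-based approximations of GELU; the sketch below assumes the widely used sqrt(2/pi) ≈ 0.7978845608 form. Treat the exact constants as an assumption about, not a quote of, CK's FastGelu implementation.

#include <cmath>
#include <iostream>

// Tanh approximation of GELU, as commonly used by "fast gelu" epilogues
// (assumed form; CK's FastGelu may differ in constants or refactoring).
float fast_gelu(float x)
{
    const float k0 = 0.7978845608f; // sqrt(2/pi)
    const float k1 = 0.044715f;
    return 0.5f * x * (1.f + std::tanh(k0 * (x + k1 * x * x * x)));
}

int main()
{
    for(float x : {-2.f, -1.f, 0.f, 1.f, 2.f})
        std::cout << "fast_gelu(" << x << ") = " << fast_gelu(x) << '\n';
}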