Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
df467969
Commit
df467969
authored
Dec 04, 2023
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
e9047ab9
ff24b537
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
323 additions
and
68 deletions
+323
-68
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp
...ration/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp
+142
-5
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+10
-0
include/ck/utility/tuple_helper.hpp
include/ck/utility/tuple_helper.hpp
+88
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp
..._instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp
+23
-20
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
...shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
+15
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp
...shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp
...shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp
...shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp
...shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp
+2
-3
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
...wd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+37
-36
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp
View file @
df467969
...
...
@@ -7,6 +7,20 @@
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace
lds_direct_load
{
__device__
void
sched_barrier
()
{
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
// When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of
// `sched_barrier` is necessary to make sure no instructions that use the loaded memory
// are scheduled by the compiler before the `waitcnt` instruction.
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
}
// namespace lds_direct_load
namespace
ck
{
template
<
index_t
NumPrefetch
>
...
...
@@ -17,7 +31,6 @@ template <>
struct
GridwiseGemmPipeline_v4
<
1
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
/* num_loop */
)
{
return
true
;
}
...
...
@@ -31,13 +44,13 @@ struct GridwiseGemmPipeline_v4<1>
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockBuffer
s
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockBuffer
s
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
...
...
@@ -45,18 +58,22 @@ struct GridwiseGemmPipeline_v4<1>
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
ABlockBuffer
s
&
a_block_buf
s
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
BBlockBuffer
s
&
b_block_buf
s
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
static_assert
(
ABlockBuffers
::
Size
()
==
1
&&
BBlockBuffers
::
Size
()
==
1
);
auto
&
a_block_buf
=
a_block_bufs
.
At
(
I0
);
auto
&
b_block_buf
=
b_block_bufs
.
At
(
I0
);
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf
);
...
...
@@ -74,10 +91,12 @@ struct GridwiseGemmPipeline_v4<1>
do
{
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf
);
...
...
@@ -92,10 +111,128 @@ struct GridwiseGemmPipeline_v4<1>
// tail
{
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
};
// 2-stages prefetch
template
<
>
struct
GridwiseGemmPipeline_v4
<
2
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
{
return
num_loop
%
2
==
0
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
(
num_loop
/
2
)
>
1
;
}
template
<
bool
HasMainLoop
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffers
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffers
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffers
&
a_block_bufs
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffers
&
b_block_bufs
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
static_assert
(
ABlockBuffers
::
Size
()
==
2
&&
BBlockBuffers
::
Size
()
==
2
);
auto
&
a_block_buf1
=
a_block_bufs
.
At
(
I0
);
auto
&
a_block_buf2
=
a_block_bufs
.
At
(
I1
);
auto
&
b_block_buf1
=
b_block_bufs
.
At
(
I0
);
auto
&
b_block_buf2
=
b_block_bufs
.
At
(
I1
);
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf1
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf2
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf2
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
blockwise_gemm
.
Run
(
a_block_buf1
,
b_block_buf1
,
c_thread_buf
);
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf1
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
blockwise_gemm
.
Run
(
a_block_buf2
,
b_block_buf2
,
c_thread_buf
);
i
+=
2
;
}
while
(
i
<
(
num_loop
-
2
));
}
// tail
{
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf2
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf2
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
blockwise_gemm
.
Run
(
a_block_buf1
,
b_block_buf1
,
c_thread_buf
);
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
blockwise_gemm
.
Run
(
a_block_buf2
,
b_block_buf2
,
c_thread_buf
);
}
}
};
}
// namespace ck
include/ck/utility/amd_buffer_addressing.hpp
View file @
df467969
...
...
@@ -972,6 +972,15 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const
int32x4_t
src_resource
=
make_wave_buffer_resource
(
global_ptr
,
src_element_space_size
);
const
index_t
global_offset_bytes
=
is_valid
?
global_offset
*
sizeof
(
T
)
:
0x80000000
;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T
*
lds_ptr
=
lds_base_ptr
+
lds_offset
;
auto
const
lds_ptr_sgpr
=
__builtin_amdgcn_readfirstlane
((
reinterpret_cast
<
uintptr_t
>
(
lds_ptr
)));
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"v"
(
global_offset_bytes
),
"s"
(
src_resource
));
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
reinterpret_cast
<
__attribute__
((
address_space
(
3
)))
uint32_t
*>
(
...
...
@@ -979,6 +988,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
llvm_amdgcn_raw_buffer_load_lds
(
src_resource
,
lds_ptr
,
sizeof
(
uint32_t
),
global_offset_bytes
,
0
,
0
,
0
);
#endif
}
}
// namespace ck
include/ck/utility/tuple_helper.hpp
View file @
df467969
...
...
@@ -5,6 +5,7 @@
#include "functional4.hpp"
#include "tuple.hpp"
#include "is_detected.hpp"
namespace
ck
{
...
...
@@ -33,6 +34,28 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
ty
);
}
template
<
typename
...
X
,
typename
...
Y
>
__host__
__device__
constexpr
auto
concat_tuple
(
const
Tuple
<
X
...
>&
tx
,
const
Tuple
<
Y
...
>&
ty
)
{
return
unpack2
(
[
&
](
auto
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
tx
,
ty
);
}
// Support any number of tuples to concat (also 1)
template
<
typename
...
X
>
__host__
__device__
constexpr
auto
concat_tuple
(
const
Tuple
<
X
...
>&
tx
)
{
return
tx
;
}
template
<
typename
...
X
,
typename
...
Tuples
>
__host__
__device__
constexpr
auto
concat_tuple
(
const
Tuple
<
X
...
>&
tx
,
const
Tuples
&
...
tuples
)
{
return
concat_tuple
(
tx
,
concat_tuple
(
tuples
...));
}
namespace
detail
{
template
<
typename
F
,
typename
X
,
index_t
...
Is
>
...
...
@@ -78,4 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
{});
}
// By default unroll to the flatten
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
>
__host__
__device__
constexpr
auto
UnrollNestedTuple
(
const
Tuple
<>&
element
)
{
return
element
;
}
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
,
typename
T
>
__host__
__device__
constexpr
auto
UnrollNestedTuple
(
const
T
&
element
)
{
return
make_tuple
(
element
);
}
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
,
typename
...
Ts
>
__host__
__device__
constexpr
auto
UnrollNestedTuple
(
const
Tuple
<
Ts
...
>&
tuple
)
{
if
constexpr
(
Depth
==
MaxDepth
)
{
return
tuple
;
}
else
{
return
unpack
(
[
&
](
auto
&&
...
ts
)
{
return
concat_tuple
(
UnrollNestedTuple
<
Depth
+
1
,
MaxDepth
>
(
ts
)...);
},
tuple
);
}
}
template
<
typename
...
Ts
>
__host__
__device__
constexpr
auto
TupleReverse
(
const
Tuple
<
Ts
...
>&
tuple
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
Idx
=
Number
<
Tuple
<
Ts
...
>::
Size
()
-
i
-
1
>
;
return
tuple
.
At
(
Idx
{});
},
Number
<
Tuple
<
Ts
...
>::
Size
()
>
{});
}
// Reduce tuple values in specific range using Function
template
<
index_t
Idx
,
index_t
End
,
typename
F
,
typename
...
Ts
>
__host__
__device__
constexpr
auto
TupleReduce
(
F
&&
f
,
const
Tuple
<
Ts
...
>&
tuple
)
{
static_assert
(
Idx
<
End
,
"Wrong parameters for TupleReduce"
);
if
constexpr
(
Idx
+
1
==
End
)
{
return
tuple
.
At
(
Number
<
Idx
>
{});
}
else
{
return
f
(
tuple
.
At
(
Number
<
Idx
>
{}),
TupleReduce
<
Idx
+
1
,
End
>
(
f
,
tuple
));
}
}
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
template
<
typename
...
Ts
>
__host__
__device__
constexpr
auto
IsNestedTuple
(
const
Tuple
<
Ts
...
>&
)
{
return
(
is_detected
<
is_tuple
,
Ts
>::
value
||
...);
}
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp
View file @
df467969
...
...
@@ -23,19 +23,20 @@ using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
#ifdef CK_ENABLE_BF16
// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ck
::
Tuple
<
BF16
,
BF16
>
,
ck
::
Tuple
<
BF16
,
BF16
>
,
ck
::
Tuple
<>
,
BF16
,
ScaleAdd
,
ScaleAdd
,
PassThrough
>>>&
instances
);
// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347
// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
// NDHWGC,
// GKZYXC,
// ck::Tuple<>,
// NDHWGK,
// ck::Tuple<BF16, BF16>,
// ck::Tuple<BF16, BF16>,
// ck::Tuple<>,
// BF16,
// ScaleAdd,
// ScaleAdd,
// PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
...
...
@@ -151,13 +152,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
Tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
>>
&&
is_same_v
<
WeiDataType
,
ck
::
Tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
>>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeType
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
op_ptrs
);
}
// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347
// if constexpr(is_same_v<InDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
// is_same_v<WeiDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
// is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType,
// ck::bhalf_t>)
// {
// add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// op_ptrs);
// }
#endif
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
Tuple
<
int8_t
,
int8_t
>>
&&
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
df467969
...
...
@@ -35,7 +35,21 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances =
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
32
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
2
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
2
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
0
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
0
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
64
,
32
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
2
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
2
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp
View file @
df467969
...
...
@@ -32,7 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp
View file @
df467969
...
...
@@ -32,7 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp
View file @
df467969
...
...
@@ -31,7 +31,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp
View file @
df467969
...
...
@@ -24,8 +24,7 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmMNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
...
...
@@ -34,7 +33,7 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances =
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
Gemm
MNPadding
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
Gemm
Default
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
View file @
df467969
...
...
@@ -9,42 +9,43 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ck
::
Tuple
<
BF16
,
BF16
>
,
ck
::
Tuple
<
BF16
,
BF16
>
,
ck
::
Tuple
<>
,
BF16
,
ScaleAdd
,
ScaleAdd
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347
// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
// NDHWGC,
// GKZYXC,
// ck::Tuple<>,
// NDHWGK,
// ck::Tuple<BF16, BF16>,
// ck::Tuple<BF16, BF16>,
// ck::Tuple<>,
// BF16,
// ScaleAdd,
// ScaleAdd,
// PassThrough>>>& instances)
// {
// add_device_operation_instances(
// instances,
// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3,
// NDHWGC,
// GKZYXC,
// NDHWGK,
// ConvFwdDefault>{});
// add_device_operation_instances(
// instances,
// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3,
// NDHWGC,
// GKZYXC,
// NDHWGK,
// ConvFwd1x1P0>{});
// add_device_operation_instances(
// instances,
// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3,
// NDHWGC,
// GKZYXC,
// NDHWGK,
// ConvFwd1x1S1P0>{});
// }
}
// namespace instance
}
// namespace device
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment