gaoqiong / composable_kernel_ROCM · Commit 55cdf2b9

Authored Oct 27, 2024 by carlushuang

    Merge remote-tracking branch 'origin/develop' into ck_tile/layernorm_fusion

Parents: 4b59b5c9, b098b71b

Changes: 132. Showing 20 changed files with 1458 additions and 7 deletions (+1458 / -7).
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp                                                    +1   -1
include/ck_tile/ops/reduce/block/block_reduce.hpp                                                                                 +171 -1
include/ck_tile/ops/softmax.hpp                                                                                                   +8   -0
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp                                                                            +81  -0
include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp                                                                    +16  -0
include/ck_tile/ops/topk.hpp                                                                                                      +8   -0
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp                                                                           +113 -0
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp                                                                   +22  -0
include/ck_tile/ops/topk_softmax.hpp                                                                                              +10  -0
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp                                                                   +166 -0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp                                                  +123 -0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp                                                    +63  -0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp                                                   +46  -0
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp                                                     +5   -5
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp                                              +105 -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp +35  -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp    +179 -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp                                 +4   -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc                             +24  -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp                              +278 -0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp:

@@ -136,7 +136,7 @@ struct Layernorm2dFwdPipelineTwoPass
         // compute inv-std
         auto inv_std = tile_elementwise_in(
             [&](const auto& v_) {
-                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_) + epsilon);
+                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
             },
             var);
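The one-line change above moves epsilon inside the square root, switching the inverse-std from 1/(sqrt(var) + eps) to the standard layernorm form 1/sqrt(var + eps). Both stay finite at var = 0, but the values differ noticeably for tiny variances (1/eps = 1e5 versus 1/sqrt(eps) ≈ 316 for eps = 1e-5). A minimal standalone comparison of the two formulas (plain C++, independent of ck_tile):

#include <cmath>
#include <cstdio>

int main()
{
    const float eps = 1e-5f;
    for(float var : {0.0f, 1e-6f, 1.0f})
    {
        float before = 1.0f / (std::sqrt(var) + eps); // old: epsilon added after the sqrt
        float after  = 1.0f / std::sqrt(var + eps);   // new: epsilon added to the variance
        std::printf("var=%g  old=%g  new=%g\n", var, before, after);
    }
    return 0;
}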
include/ck_tile/ops/reduce/block/block_reduce.hpp:

@@ -4,9 +4,14 @@
 #pragma once

 #include "ck_tile/core.hpp"
+#include <tuple>

 namespace ck_tile {

+/*
+ * TODO: block_tile_reduce_sync() currently has a limitation
+ * Y dim must have at least one dim not been reduced
+ */
 // synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
 template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
 CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,

@@ -22,7 +27,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
     constexpr index_t idim_p_lane = NDimP - 1;

-    const auto ps_idx = make_array<index_t>(get_block_id(), get_lane_id());
+    const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution());
     const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);

     constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();

@@ -104,6 +109,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
     });
 }

+/*
+ * this version is faster, using xor to do reduce, no need broadcast anymore
+ * TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
+ */
+template <typename AccDistributedTensor_, typename ReduceFunc>
+CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
+                                               const ReduceFunc& reduce_func)
+{
+    using Dstr             = typename AccDistributedTensor_::StaticTileDistribution;
+    using DstrEncode       = typename Dstr::DstrEncode;
+    using DstrEncodeDetail = typename DstrEncode::detail;
+
+    constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
+    constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
+
+    constexpr index_t idim_p_lane = NDimP - 1;
+
+    constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
+
+    // loop over thread data
+    static_for<0, thread_buf_size, 1>{}([&](auto i) {
+        auto v_local = acc_tensor.get_thread_buffer()[i];
+
+        // cross-lane reduce for replication
+        // only reduce on R dimension correspond to lane
+        // (lane id maps to this R dimension)
+        static_for<0, NDimR, 1>{}([&](auto idim_r) {
+            // FIXME: nasty to use does_p_own_r_
+            if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
+            {
+                constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
+
+                constexpr index_t lid_over_rid_derivative =
+                    DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
+
+                static_assert(is_power_of_two_integer(r_length),
+                              "wrong! only support power of 2 reduction");
+
+                constexpr index_t nstage = integer_log2_floor(r_length);
+
+                // reduction sweep forward
+                static_for<0, nstage, 1>{}([&](auto istage) {
+                    // xor
+                    index_t src_lane =
+                        __lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
+
+                    // pull data from remote lane
+                    const auto v_remote = warp_shuffle(v_local, src_lane);
+
+                    // reduce
+                    v_local = reduce_func(v_local, v_remote);
+                });
+            }
+        });
+
+        acc_tensor.get_thread_buffer()(i) = v_local;
+    });
+}
+
 // FIXME: this is for 2D to 1D reduce only, need to support n-D
 template <typename AccDistributedTensor_,
           typename InDistributedTensor_,

@@ -175,6 +239,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
 #endif
 }

+/*
+ * TODO: block_tile_reduce() currently has a limitation
+ * Y dim must have at least one dim not been reduced
+ */
 template <typename AccDataType_,
           typename InDistributedTensor_,
           index_t... InReduceDims,

@@ -208,4 +276,106 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
     return acc_tensor;
 }

+// this version only support 2D->1D reduce (reduce-dim=seq<0, 1>)
+// this version only support in/acc/out datatypes are the same
+// this version will call thread/warp+sync in one function call
+//
+template <typename InDistributedTensor_>
+struct BlockReduce2D
+{
+    using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
+    using InDataType          = typename InDistributedTensor::DataType;
+
+    CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_)
+        : t(t_), reduce_init(reduce_init_)
+    {
+    }
+
+    CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
+    {
+        using ReduceDim = sequence<1>; // hard coded
+        constexpr auto acc_dstr =
+            make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding(
+                InDistributedTensor::get_tile_distribution().get_static_tile_distribution_encoding(),
+                ReduceDim{}));
+        return make_static_distributed_tensor<InDataType>(acc_dstr);
+    }
+
+    // return number of pixels each lane need to reduce
+    CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
+    {
+        constexpr auto spans = InDistributedTensor::get_distributed_spans();
+    }
+
+    // Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan
+    // this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim
+    // internally
+    // For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2
+    // element , and the first element is always ignored For simplicity, will always try from
+    // right-to-left to find alone which Y dim to split
+    template <typename ReduceFunc,
+              typename ReduceSyncFunc,
+              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
+    CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
+                                        const ReduceSyncFunc& reduce_sync_func,
+                                        ReducePacksPerXDim = {}) const
+    {
+        constexpr auto spans = InDistributedTensor::get_distributed_spans();
+
+        constexpr auto row_y_unpacks = [&]() {
+            constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
+            constexpr auto row_y_size = reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
+            constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
+            static_assert(row_y_size % row_y_packs == 0);
+            constexpr auto row_y_slice_size = row_y_size / row_y_packs;
+            constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
+            constexpr auto unpacks    = slice_info[number<1>{}];
+            return unpacks;
+        }();
+
+        auto acc_tensor = MakeDstBlockTile();
+
+        // in-thread reduction
+        // FIXME: hard coded to be 2D to 1D reduction
+        sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
+            constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
+            auto acc = acc_tensor[acc_dstr_idx];
+            sweep_tile_uspan(
+                spans[number<1>{}],
+                [&](auto... dstr_idx_i1) {
+                    acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
+                },
+                row_y_unpacks);
+            acc_tensor(acc_dstr_idx) = acc;
+        });
+
+        // TODO: always use xor to do cross-lane reduce
+        block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);
+
+        return acc_tensor;
+    }
+
+    template <typename ReduceFunc>
+    CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
+    {
+        return operator()(reduce_func, reduce_func);
+    }
+
+    InDistributedTensor t;
+    InDataType reduce_init;
+};
+
+// deduction guide
+template <typename T>
+CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D<T>;
+
 } // namespace ck_tile
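The new block_tile_reduce_xor_sync is a butterfly reduction: at stage s every lane exchanges its partial value with the lane whose id differs by one bit (lane ^ (derivative << s)), so after log2(r_length) stages every participating lane already holds the full result and the separate broadcast pass of block_tile_reduce_sync is unnecessary. A minimal HIP sketch of the same pattern on a single 64-lane wavefront, outside the tile-distribution machinery (illustrative only, not ck_tile API):

#include <hip/hip_runtime.h>

// Butterfly (XOR) reduction: every lane of the wavefront ends up holding the
// sum of all 64 inputs, with no extra broadcast step.
__global__ void wave_sum_xor(const float* in, float* out)
{
    float v = in[threadIdx.x];
    for(int stride = 1; stride < 64; stride <<= 1)
    {
        // pull the partial result from the lane whose id differs in one bit
        float remote = __shfl_xor(v, stride, 64);
        v += remote;
    }
    out[threadIdx.x] = v; // every lane holds the same total
}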
include/ck_tile/ops/softmax.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"

#define _BLOCK_SOFTMAX_USE_UNPACK2 0

namespace ck_tile {

/*
simple 2d softmax implementation, along row (dim=1)
requirement:
    1). each row is within a warp
    2). data type must be a dword
*/
template <typename Problem_, typename Policy_ = void>
struct BlockSoftmax2D
{
    using Problem  = remove_cvref_t<Problem_>;
    using Policy   = remove_cvref_t<Policy_>;
    using DataType = typename Problem::DataType;

    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE void operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
    {
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
#if _BLOCK_SOFTMAX_USE_UNPACK2
        const auto f_max3 = [](auto e0, auto e1, auto e2) {
            float rtn;
            asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2));
            return rtn;
        };
        const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; };
#endif

        // compute row max
        auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
#if _BLOCK_SOFTMAX_USE_UNPACK2
        auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
#else
        auto row_max = reduce_row_max(f_max);
#endif

        sweep_tile<DistributedTensor>([&](auto idx) {
            constexpr auto row_id = make_tuple(idx[number<0>{}]);
            y(idx) = exp(x[idx] - row_max[row_id]);
        });

        // compute row sum
        auto reduce_row_sum = BlockReduce2D<decltype(y)>{y, DataType{0}};
#if _BLOCK_SOFTMAX_USE_UNPACK2
        auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
#else
        auto row_sum = reduce_row_sum(f_sum);
#endif

        // reciprocal
        auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
        sweep_tile(row_sum, [&](auto idx) { r(idx) = DataType{1} / row_sum(idx); });

        // scale
        sweep_tile<DistributedTensor>([&](auto idx) {
            constexpr auto row_id = make_tuple(idx[number<0>{}]);
            y(idx) = y(idx) * r(row_id);
        });
    }

    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
    {
        auto y = DistributedTensor{}; // distributed tensor
        operator()(x, y, number<dim>{});
        return y;
    }
};

} // namespace ck_tile
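BlockSoftmax2D follows the standard numerically stable recipe per row: subtract the row max before exponentiating, then normalize by the row sum (computed above as one reciprocal followed by per-element multiplies). A scalar reference of the same math, handy for validating the block implementation on the host (plain C++, illustrative only):

#include <algorithm>
#include <cmath>
#include <vector>

// Reference row softmax: y[j] = exp(x[j] - max(x)) / sum_k exp(x[k] - max(x)).
std::vector<float> softmax_row(const std::vector<float>& x)
{
    const float m = *std::max_element(x.begin(), x.end()); // row max, for stability
    std::vector<float> y(x.size());
    float sum = 0.f;
    for(std::size_t j = 0; j < x.size(); ++j)
    {
        y[j] = std::exp(x[j] - m);
        sum += y[j];
    }
    const float r = 1.f / sum; // one reciprocal, then multiply (as the block code does)
    for(auto& v : y)
        v *= r;
    return y;
}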
include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename DataType_>
struct BlockSoftmax2DProblem
{
    using DataType = remove_cvref_t<DataType_>;
};

} // namespace ck_tile
include/ck_tile/ops/topk.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/*
simple 2d topk implementation, along row (dim=1)
requirement:
    1). each row is within a warp
*/
template <typename Problem_, typename Policy_ = void>
struct BlockTopkStream2D
{
    using Problem   = remove_cvref_t<Problem_>;
    using Policy    = remove_cvref_t<Policy_>;
    using DataType  = typename Problem::DataType;
    using IndexType = typename Problem::IndexType;

    // TODO: if DataType is subdword, need pack into single dword to use argmax
    struct ArgmaxPacket
    {
        DataType arg;
        index_t value;
    };

    template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
    CK_TILE_DEVICE void operator()(const DistributedTensor& x,
                                   const OutWindow& out_window,
                                   const IdxWindow& idx_window,
                                   index_t k,
                                   number<dim> = {})
    {
        OutWindow out_window_tmp = out_window;
        IdxWindow idx_window_tmp = idx_window;
        static_assert(
            std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
            std::is_same_v<typename DistributedTensor::DataType, DataType>);
        static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
        DistributedTensor x_tmp = x;
        constexpr auto dst_dist = typename IdxWindow::TileDstr{};

        // argmax for topk
        const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
            return e0.arg > e1.arg ? e0 : e1;
        };

        for(index_t i_k = 0; i_k < k; i_k++)
        {
            constexpr auto span_2d = DistributedTensor::get_distributed_spans();

            auto packet = [&]() {
                auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
                sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            tmp.get_tile_distribution(), make_tuple(idx0, idx1));
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        ArgmaxPacket t;
                        t.arg   = x_tmp(i_j_idx); // !!! we reference x here
                        t.value = tile_idx.at(number<1>{});
                        tmp(i_j_idx) = t;
                    });
                });
                return tmp;
            }();

            auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
            auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
            block_tile_reduce_xor_sync(r, f_argmax);

            auto o = make_static_distributed_tensor<DataType>(dst_dist);
            auto i = make_static_distributed_tensor<IndexType>(dst_dist);

            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    ArgmaxPacket tmp = r(i_j_idx);
                    o(i_j_idx) = tmp.arg;
                    i(i_j_idx) = tmp.value;
                });
            });

            // update value
            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    const auto tile_idx = get_x_indices_from_distributed_indices(
                        x.get_tile_distribution(), make_tuple(idx0, idx1));
                    auto col_id = tile_idx.at(number<1>{});
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
                                                                  : x_tmp(i_j_idx);
                });
            });

            if(threadIdx.x % Problem::ColLanes == 0)
            {
                store_tile(out_window_tmp, o);
                store_tile(idx_window_tmp, i);
            }
            move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
            move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
        }
    }
};

} // namespace ck_tile
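BlockTopkStream2D streams the top-k out one element per iteration: each pass reduces the row with an argmax packet (value plus column index), stores the winner, then masks that column with -inf so the next pass finds the runner-up. The same idea in scalar form, as a host-side reference (illustrative only, not ck_tile API):

#include <limits>
#include <utility>
#include <vector>

// Streamed top-k by repeated argmax: O(k * n), matching the block code's k-loop.
std::vector<std::pair<float, int>> topk_stream(std::vector<float> x, int k)
{
    std::vector<std::pair<float, int>> out;
    for(int i = 0; i < k; ++i)
    {
        int best = 0;
        for(int j = 1; j < static_cast<int>(x.size()); ++j)
            if(x[j] > x[best])
                best = j;
        out.emplace_back(x[best], best);
        x[best] = -std::numeric_limits<float>::infinity(); // mask the winner for the next pass
    }
    return out;
}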
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/*
simple 2d topk implementation, along row (dim=1)
requirement:
    1). each row is within a warp
*/
template <typename DataType_, typename IndexType_, index_t ColLanes_>
struct BlockTopkStream2DProblem
{
    using DataType  = remove_cvref_t<DataType_>;
    using IndexType = remove_cvref_t<IndexType_>;
    static constexpr index_t ColLanes = ColLanes_;
};

} // namespace ck_tile
include/ck_tile/ops/topk_softmax.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

struct TopkSoftmaxHostArgs
{
    const void* p_input;
    void* p_output;
    void* p_indices;
    index_t num_rows;
    index_t num_experts;
    index_t topk;
    index_t stride_input;  // row stride for input, at least experts
    index_t stride_output; // row stride for output/indices, at least topk
};

template <typename Pipeline_>
struct TopkSoftmaxKernel
{
    using Pipeline   = remove_cvref_t<Pipeline_>;
    using Problem    = remove_cvref_t<typename Pipeline::Problem>;
    using InputType  = typename Problem::InputType;
    using WeightType = typename Problem::WeightType;
    using IndexType  = typename Problem::IndexType;

    struct TopkSoftmaxKargs
    {
        const void* p_input;
        void* p_output;
        void* p_indices;
        index_t num_rows;
        index_t num_experts;
        index_t topk;
        index_t stride_input;  // row stride for input, at least experts
        index_t stride_output; // row stride for output/indices, at least topk
    };

    using Kargs = TopkSoftmaxKargs;
    using Hargs = TopkSoftmaxHostArgs;

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        if constexpr(Problem::LaunchType > 0)
        {
            int num_cu = [&]() {
                hipDeviceProp_t dev_prop;
                hipDevice_t dev;
                HIP_CHECK_ERROR(hipGetDevice(&dev));
                HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
                return dev_prop.multiProcessorCount;
            }();
            return dim3(num_cu * Problem::LaunchType);
        }
        else
        {
            const int num_warps  = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
            const int num_blocks = (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
            return dim3(num_blocks);
        }
    }

    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_input       = h.p_input;
        k.p_output      = h.p_output;
        k.p_indices     = h.p_indices;
        k.num_rows      = h.num_rows;
        k.num_experts   = h.num_experts;
        k.topk          = h.topk;
        k.stride_input  = h.stride_input;
        k.stride_output = h.stride_output;
        return k;
    }

    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
        if(block_row_id > kargs.num_rows)
            return;

        index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
        index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
        index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);

        const auto input_window = [&]() {
            const InputType* p_input =
                reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                p_input,
                make_tuple(num_rows_rem, kargs.num_experts),
                make_tuple(kargs.stride_input, 1),
                number<Problem::VectorSize>{},
                number<1>{});
            auto view = pad_tensor_view(
                tmp,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                sequence<0, 1>{}); // out-most dim no need pad(leverage oob)
            return make_tile_window(
                view,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                {0, 0});
        }();

        auto output_window = [&]() {
            WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                p_output,
                make_tuple(num_rows_rem, kargs.topk),
                make_tuple(kargs.stride_output, 1),
                number<Problem::VectorSize>{},
                number<1>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<0, 0>{});
            // 1. out-most dim no need pad(leverage oob)
            // 2. we loop over topk 1-1, no need padding
            return make_tile_window(
                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
        }();

        auto indices_window = [&]() {
            IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                p_indices,
                make_tuple(num_rows_rem, kargs.topk),
                make_tuple(kargs.stride_output, 1),
                number<Problem::VectorSize>{},
                number<1>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<0, 0>{});
            // 1. out-most dim no need pad(leverage oob)
            // 2. we loop over topk 1-1, no need padding
            return make_tile_window(
                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
        }();

        Pipeline{}(input_window,
                   output_window,
                   indices_window,
                   kargs.num_rows,
                   kargs.num_experts,
                   kargs.topk,
                   block_row_id);
    }
};

} // namespace ck_tile
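The host-side argument struct maps directly onto a row-major [num_rows, num_experts] score matrix and [num_rows, topk] output/index matrices; stride_input and stride_output allow padded rows. A hedged sketch of filling TopkSoftmaxHostArgs for a dense (unpadded) MoE routing problem; the device pointers here are hypothetical:

// Assumes d_scores, d_weights, d_ids are valid device allocations (hypothetical names).
ck_tile::TopkSoftmaxHostArgs make_args(const void* d_scores, // [num_rows, num_experts] input
                                       void* d_weights,      // [num_rows, topk] softmax weights
                                       void* d_ids,          // [num_rows, topk] expert indices
                                       ck_tile::index_t num_rows,
                                       ck_tile::index_t num_experts,
                                       ck_tile::index_t topk)
{
    ck_tile::TopkSoftmaxHostArgs h{};
    h.p_input       = d_scores;
    h.p_output      = d_weights;
    h.p_indices     = d_ids;
    h.num_rows      = num_rows;
    h.num_experts   = num_experts;
    h.topk          = topk;
    h.stride_input  = num_experts; // dense rows: stride equals the row length
    h.stride_output = topk;
    return h;
}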
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>

#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif

namespace ck_tile {

template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
struct TopkSoftmaxWarpPerRowPipeline
{
    // TODO: this kernel only support warp per row
    using Problem    = remove_cvref_t<Problem_>;
    using Policy     = remove_cvref_t<Policy_>;
    using WeightType = typename Problem::WeightType;

    template <typename InputWindow, typename OutputWindow, typename IndexWindow>
    CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
                                   OutputWindow& out_window,
                                   IndexWindow& idx_window,
                                   index_t rows,
                                   index_t experts,
                                   index_t k,
                                   index_t block_row_id)
    {
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
        auto inp_win = make_tile_window_linear_raw(
            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
#else
        auto inp_win = make_tile_window_linear(
            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
#endif
        auto out_win = make_tile_window_linear(out_window.get_bottom_tensor_view(),
                                               out_window.get_window_lengths(),
                                               out_window.get_window_origin(),
                                               Policy::template MakeOutputDistribution<Problem>());
        auto idx_win = make_tile_window_linear(idx_window.get_bottom_tensor_view(),
                                               idx_window.get_window_lengths(),
                                               idx_window.get_window_origin(),
                                               Policy::template MakeOutputDistribution<Problem>());

        auto softmax = Policy::template GetSoftmax<Problem>();
        auto topk    = Policy::template GetTopk<Problem>();

        const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;

        while(1)
        {
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
            __builtin_amdgcn_sched_barrier(0);
            auto x = load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
            buffer_load_fence(number<0>{});
            __builtin_amdgcn_sched_barrier(0);
#else
            auto x = load_tile(inp_win);
#endif
            // cast and pad input data
            auto w = [&]() {
#if 0
                auto w_ = cast_tile<WeightType>(x);
                constexpr auto span_2d = decltype(w_)::get_distributed_spans();
                sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        const auto x_indices = get_x_indices_from_distributed_indices(
                            w_.get_tile_distribution(), i_j_idx);
                        const auto current_expert = x_indices.at(number<1>{});
                        // set to -INF if OOB so that later softmax can work properly
                        w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
                                                                : w_(i_j_idx);
                    });
                });
                return w_;
#else
                auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
                auto w_f = [&](auto idx) {
                    w_(idx) = type_convert<WeightType>(x(idx));
                    const auto x_indices =
                        get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
                    const auto current_expert = x_indices.at(number<1>{});
                    w_(idx) =
                        current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
                };
                tile_sweeper ts{w_, w_f};
                ts();
                return w_;
#endif
            }();

            // softmax
            auto y = softmax(w);

            topk(y, out_win, idx_win, k);

            // check exit
            if constexpr(Problem::LaunchType == 0)
            {
                break;
            }
            else
            {
                block_row_id += grid_rows_per_loop;
                if(block_row_id >= rows)
                    break;
            }

            move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
            move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
            move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
        }
    }
};

} // namespace ck_tile
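When Problem::LaunchType > 0 the kernel is launched with a fixed grid (number of CUs times LaunchType, see GridSize in the kernel header above) and the while(1) loop becomes a persistent grid-stride loop over rows: each block advances its tile windows by gridDim.x * RowsPerBlock until it runs past the row count. A stripped-down sketch of that persistent pattern, outside the tile-window machinery (illustrative only):

#include <hip/hip_runtime.h>

// Persistent grid-stride loop over rows: a fixed number of blocks cooperatively
// sweep an arbitrary number of rows, ROWS_PER_BLOCK at a time.
template <int ROWS_PER_BLOCK>
__global__ void persistent_rows(int num_rows /*, ... data pointers ... */)
{
    int row          = blockIdx.x * ROWS_PER_BLOCK;
    const int stride = gridDim.x * ROWS_PER_BLOCK; // rows consumed per grid iteration
    while(row < num_rows)
    {
        // process_rows(row, ...); // hypothetical per-block work on this row group
        row += stride;
    }
}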
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"

namespace ck_tile {

struct TopkSoftmaxWarpPerRowPolicy
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
    {
        // TODO: Y dim must have one dim that is not reduced
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>,
                tuple<sequence<Problem::IssuesPerCol,
                               Problem::WarpsPerBlock,
                               Problem::RowsPerWarpPerColIssue>,
                      sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
                tuple<sequence<1>, sequence<1, 2>>,
                tuple<sequence<1>, sequence<2, 1>>,
                sequence<1, 2, 2>,
                sequence<0, 0, 2>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<Problem::LanesPerRow>, // repeat this one
                tuple<sequence<Problem::IssuesPerCol,
                               Problem::WarpsPerBlock,
                               Problem::RowsPerWarpPerColIssue>,
                      sequence<1>>, // each row write out single element
                tuple<sequence<1>, sequence<1, 0>>,
                tuple<sequence<1>, sequence<2, 0>>,
                sequence<1, 2>,
                sequence<0, 0>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
    {
        using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
        return BlockSoftmax2D<softmax_problem>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
    {
        using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
                                                      typename Problem::IndexType,
                                                      Problem::LanesPerRow>;
        // Note: replicate is LanesPerRow
        return BlockTopkStream2D<topk_problem>{};
    }
};

} // namespace ck_tile
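The input distribution places LanesPerRow consecutive lanes on one row, each lane covering VectorSize adjacent experts, and stacks RowsPerWarpPerColIssue rows into one warp. A sketch of the per-lane coordinate arithmetic this encoding implies (a derivation aid only; the real mapping is produced by tile_distribution_encoding, and the names below are illustrative):

// Which (row, first expert column) a lane owns within its warp,
// given lanes_per_row = Problem::LanesPerRow and vector = Problem::VectorSize.
struct LaneCoord
{
    int row_in_warp;  // which row of the warp's row group
    int first_expert; // first expert column this lane loads
};

constexpr LaneCoord lane_coord(int lane, int lanes_per_row, int vector)
{
    return {lane / lanes_per_row, (lane % lanes_per_row) * vector};
}

// e.g. 8 lanes per row, vector=1: lane 9 handles row 1, expert 1
static_assert(lane_coord(9, 8, 1).row_in_warp == 1, "");
static_assert(lane_coord(9, 8, 1).first_expert == 1, "");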
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename InputType_,
          typename WeightType_,
          typename IndexType_,
          index_t Experts_,
          index_t IssuesPerCol_  = 2, // issue along col, to make sure block_reduce() OK
          index_t BytesPerIssue_ = sizeof(InputType_),
          index_t LaunchType_    = 0, // 0-streaming, >0, persistent #occupancy
          index_t BlockSize_     = 256>
struct TopkSoftmaxWarpPerRowProblem
{
    // TODO: this kernel only support warp per row
    using InputType  = remove_cvref_t<InputType_>;
    using WeightType = remove_cvref_t<WeightType_>;
    using IndexType  = remove_cvref_t<IndexType_>;

    static constexpr index_t LaunchType    = LaunchType_;
    static constexpr index_t Experts       = Experts_;
    static constexpr index_t BytesPerIssue = BytesPerIssue_;
    static constexpr index_t IssuesPerCol  = IssuesPerCol_;
    static constexpr index_t BlockSize     = BlockSize_;
    static constexpr index_t WarpSize      = get_warp_size();

    static_assert(BytesPerIssue % sizeof(InputType) == 0);
    static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
    static_assert(Experts % VectorSize == 0);
    static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
    static_assert(WarpSize % LanesPerRow == 0);
    static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
    static constexpr index_t RowsPerWarp   = IssuesPerCol * RowsPerWarpPerColIssue;
    static constexpr index_t IssuesPerRow  = Experts / (LanesPerRow * VectorSize);
    static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
    static constexpr index_t RowsPerBlock  = RowsPerWarp * WarpsPerBlock;
};

} // namespace ck_tile
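All the derived constants are simple integer arithmetic on Experts, the warp size, and the issue parameters. A worked example, re-evaluating the formulas above under the assumption of float input, 8 experts, and a 64-lane AMD wavefront (the defaults IssuesPerCol=2 and BlockSize=256):

// Worked example of TopkSoftmaxWarpPerRowProblem's derived constants,
// assuming InputType=float, Experts=8, IssuesPerCol=2, BlockSize=256, WarpSize=64.
constexpr int WarpSize      = 64;
constexpr int Experts       = 8;
constexpr int BytesPerIssue = sizeof(float);                          // 4
constexpr int VectorSize    = BytesPerIssue / (int)sizeof(float);     // 1
constexpr int LanesPerRow   = (Experts / VectorSize < WarpSize)
                                  ? Experts / VectorSize
                                  : WarpSize;                         // 8
constexpr int RowsPerWarpPerColIssue = WarpSize / LanesPerRow;        // 8
constexpr int IssuesPerCol  = 2;
constexpr int RowsPerWarp   = IssuesPerCol * RowsPerWarpPerColIssue;  // 16
constexpr int IssuesPerRow  = Experts / (LanesPerRow * VectorSize);   // 1
constexpr int WarpsPerBlock = 256 / WarpSize;                         // 4
constexpr int RowsPerBlock  = RowsPerWarp * WarpsPerBlock;            // 64

static_assert(RowsPerBlock == 64, "one 256-thread block covers 64 rows of 8 experts");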
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp:

@@ -45,10 +45,10 @@ __global__ void
     if(row_idx < m && col_idx < n)
     {
-        AccDataType v_acc  = static_cast<AccDataType>(0.0);
-        ComputeTypeA v_a   = static_cast<ComputeTypeA>(0.0);
-        ComputeTypeB v_b   = static_cast<ComputeTypeB>(0.0);
-        CDataType v_c      = static_cast<CDataType>(0.0);
+        AccDataType v_acc{0};
+        ComputeTypeA v_a{0};
+        ComputeTypeB v_b{0};
+        CDataType v_c{0};
         for(int k_idx = 0; k_idx < k; ++k_idx)
         {

@@ -76,7 +76,7 @@ __global__ void
             // apply b_element_op
             b_element_op(v_b, p_b_grid[element_idx_b]);
             // multiply and accumulate
-            v_acc += static_cast<AccDataType>(v_a) * static_cast<AccDataType>(v_b);
+            v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
         }
         // apply c_element_op
         c_element_op(v_c, v_acc);
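This change swaps static_cast for CK's type_convert when widening the operands to AccDataType. For native arithmetic types the two agree; the point of a dedicated conversion helper is storage-only types such as bhalf_t, where the value lives in a bit pattern and a plain static_cast does not perform (or even compile as) the intended value conversion. A minimal illustration of the idea with a hypothetical bf16 type (not CK's implementation):

#include <cstdint>
#include <cstring>

// Hypothetical storage-only bf16: holds the top 16 bits of a float.
struct my_bf16 { uint16_t bits; };

// A value-aware convert helper: reconstruct the float the bf16 represents.
inline float convert_to_float(my_bf16 v)
{
    uint32_t u = static_cast<uint32_t>(v.bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f; // static_cast<float>(v) would not compile for this struct
}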
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp:

@@ -96,6 +96,87 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
                                                   MultiplyMultiply>>>& instances);
 #endif

+#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
+void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, I8, I8,
+        Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
+
+void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, I8, I8,
+        Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
+
+void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, I8, I8,
+        Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
+
+void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, I8, I8,
+        Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
+
+void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, I8, I8,
+        Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
+
+void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, Col, Tuple<Row, Col>, Row, I8, I8,
+        Tuple<F32, F32>, BF16, PassThrough, PassThrough, MultiplyMultiply>>>& instances);
+#endif
+
 template <typename ADataType,
           typename BDataType,
           typename CDataType,

@@ -155,6 +236,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                 op_ptrs);
             }
         }
 #endif
+#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
+        if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
+                     is_same_v<CDataType, bhalf_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                         is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
+                add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
+                add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
+                add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
+                add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
+                add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
+            }
+        }
+#endif

         return op_ptrs;
     }
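With these declarations in place, the factory specialization can hand back the int8 x int8 -> bf16 multiply-multiply GEMM instances alongside the existing f8 ones. A hedged usage sketch, assuming the DeviceOperationInstanceFactory interface this header specializes (type aliases abbreviated as in the header; illustrative only):

// Sketch: enumerate the registered i8/i8->bf16 instances and print their names.
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
    Row, Col, ck::Tuple<Row, Col>, Row, I8, I8, ck::Tuple<F32, F32>, BF16,
    PassThrough, PassThrough, MultiplyMultiply>;

auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

for(const auto& op : op_ptrs)
    std::cout << op->GetTypeString() << '\n'; // each entry is one tuned kernel instance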
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp:

@@ -141,6 +141,41 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances = std
         // clang-format on
         >;

+template <ck::index_t NDimSpatial,
+          typename ALayout,
+          typename BLayout,
+          typename ELayout,
+          ConvolutionBackwardWeightSpecialization ConvSpec>
+using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple<
+    // clang-format off
+    //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
+    //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
+    //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
+    //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
+    // generic instance
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>,
+    // instance for small conv.K
+    // for bf16 conv.K and conv.C must be divisible by 2
+    // since half_t atomic_add require scalar_per_x_vector % 2 == 0
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>,
+    DeviceGroupedConvBwdWeight_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>
+    // clang-format on
+    >;
+
 template <ck::index_t NDimSpatial,
           typename ALayout,
           typename BLayout,
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
0 → 100644
View file @
55cdf2b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
BF16
=
ck
::
bhalf_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
namespace
ck
::
tensor_layout
::
convolution
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DynamicUnaryOp
=
ck
::
tensor_operation
::
element_wise
::
DynamicUnaryOp
;
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
static
constexpr
auto
ConvFwd1x1P0
=
ConvolutionForwardSpecialization
::
Filter1x1Pad0
;
static
constexpr
auto
ConvFwd1x1S1P0
=
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
;
static
constexpr
auto
ConvFwdOddC
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
;
static
constexpr
auto
GemmMNKPadding
=
GemmSpecialization
::
MNKPadding
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
        32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
        // clang-format on
        >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_dynamic_op_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
        // instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
        // clang-format on
        >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_dynamic_op_f32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
        // instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>
        // clang-format on
        >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_dynamic_op_int8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
        // instances for small conv.K and conv.C
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
        // clang-format on
        >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
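Each per-dtype tuple above is only a type list; it becomes usable through a matching `add_device_grouped_conv*_fwd_xdl_dynamic_op_*_instances` function that instantiates every tuple element and appends it to the caller's vector. A minimal sketch of that registration step for the f16 list, assuming CK's usual `add_device_operation_instances` helper and the `Default` forward specialization (both assumptions here, not shown in this commit):

void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, ck::Tuple<>, NHWGK,
                                                                F16, F16, ck::Tuple<>, F16,
                                                                PassThrough, PassThrough, DynamicUnaryOp>>>& instances)
{
    // Expand the tuple alias defined above into concrete device ops and
    // push one unique_ptr per tuple element into `instances`.
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_dynamic_op_f16_instances<2,
                                                             NHWGC,
                                                             GKYXC,
                                                             ck::Tuple<>,
                                                             NHWGK,
                                                             ConvolutionForwardSpecialization::Default>{});
}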
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...
@@ -373,6 +373,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<ComputeTypeA, ck::bhalf_t> &&
                      is_same_v<ComputeTypeB, ck::bhalf_t>)
         {
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
             add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances(
                 op_ptrs);
             add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instances(
...
@@ -483,6 +485,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<ComputeTypeA, ck::bhalf_t> &&
                      is_same_v<ComputeTypeB, ck::bhalf_t>)
         {
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
             add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances(
                 op_ptrs);
             add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances(
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
...
@@ -89,6 +89,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
                                                            PassThrough>>>& instances);
 #endif
 #ifdef CK_ENABLE_BF16
+void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
+                                                           NHWGC,
+                                                           GKYXC,
+                                                           NHWGK,
+                                                           BF16,
+                                                           BF16,
+                                                           BF16,
+                                                           PassThrough,
+                                                           PassThrough,
+                                                           PassThrough>>>& instances);
 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                            NHWGC,
...
@@ -262,6 +274,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
                                                            PassThrough>>>& instances);
 #endif
 #ifdef CK_ENABLE_BF16
+void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
+                                                           NDHWGC,
+                                                           GKZYXC,
+                                                           NDHWGK,
+                                                           BF16,
+                                                           BF16,
+                                                           BF16,
+                                                           PassThrough,
+                                                           PassThrough,
+                                                           PassThrough>>>& instances);
 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                            NDHWGC,
...
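Each declaration added to this .inc needs a matching definition in one of the bwd-weight instance translation units. A sketch of what that definition plausibly looks like; the tuple alias `device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances` and its template parameters are assumptions for illustration, not taken from this commit:

void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK,
                                                           BF16, BF16, BF16,
                                                           PassThrough, PassThrough, PassThrough>>>& instances)
{
    // Hypothetical tuple alias of DeviceGroupedConvBwdWeight_Xdl_CShuffle configs;
    // the real alias lives in the corresponding *_xdl_instance.hpp header of this commit.
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2, NHWGC, GKYXC, NHWGK>{});
}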
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dynamic.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
using DynamicUnaryOp = ck::tensor_operation::element_wise::DynamicUnaryOp;

#ifdef CK_ENABLE_BF16
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, ck::Tuple<>, NHWGK,
                                                                BF16, BF16, ck::Tuple<>, BF16,
                                                                PassThrough, PassThrough, DynamicUnaryOp>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, ck::Tuple<>, NHWGK,
                                                                F16, F16, ck::Tuple<>, F16,
                                                                PassThrough, PassThrough, DynamicUnaryOp>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, ck::Tuple<>, NHWGK,
                                                                F32, F32, ck::Tuple<>, F32,
                                                                PassThrough, PassThrough, DynamicUnaryOp>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, ck::Tuple<>, NHWGK,
                                                                int8_t, int8_t, ck::Tuple<>, int8_t,
                                                                PassThrough, PassThrough, DynamicUnaryOp>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK,
                                                                BF16, BF16, ck::Tuple<>, BF16,
                                                                PassThrough, PassThrough, DynamicUnaryOp>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK,
                                                                F16, F16, ck::Tuple<>, F16,
                                                                PassThrough, PassThrough, DynamicUnaryOp>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK,
                                                                F32, F32, ck::Tuple<>, F32,
                                                                PassThrough, PassThrough, DynamicUnaryOp>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK,
                                                                int8_t, int8_t, ck::Tuple<>, int8_t,
                                                                PassThrough, PassThrough, DynamicUnaryOp>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DLayouts,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DDataTypes,
          typename OutDataType,
          typename ComputeType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
        NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout,
        InDataType, WeiDataType, DDataTypes, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::DynamicUnaryOp,
        ComputeType>>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial, InLayout, WeiLayout, DLayouts, OutLayout,
                                                     InDataType, WeiDataType, DDataTypes, OutDataType,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     ck::tensor_operation::element_wise::DynamicUnaryOp,
                                                     ComputeType>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
                     is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
                     DLayouts::Size() == 0)
        {
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                         is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_BF16
            if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
                         is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_INT8
            if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                         is_same_v<OutDataType, int8_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
            }
#endif
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
                          is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK> &&
                          DLayouts::Size() == 0)
        {
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                         is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
            {
                add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_BF16
            if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
                         is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_INT8
            if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                         is_same_v<OutDataType, int8_t>)
            {
                add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instances(op_ptrs);
            }
#endif
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
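For reference, a caller-side sketch of how this new factory specialization is typically queried. The DeviceOp alias below hard-codes one f16 NHWGC configuration purely for illustration, and the loop assumes BaseOperator's GetTypeString() reporting helper plus an <iostream> include:

using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2, NHWGC, GKYXC, ck::Tuple<>, NHWGK,
    ck::half_t, ck::half_t, ck::Tuple<>, ck::half_t,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::DynamicUnaryOp,
    ck::half_t>;

// Enumerate every dynamic-op instance registered above for this configuration.
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

for(const auto& op_ptr : op_ptrs)
{
    std::cout << op_ptr->GetTypeString() << '\n'; // each instance reports its tuning config
}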