Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d51701d4
Commit
d51701d4
authored
Oct 31, 2024
by
Andriy Roshchenko
Browse files
Merge remote-tracking branch 'ck_public/develop' into andriy/merge_from_public
parents
f221c2b0
c3a4800c
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1491 additions
and
10 deletions
+1491
-10
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
+81
-0
include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp
...de/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp
+16
-0
include/ck_tile/ops/topk.hpp
include/ck_tile/ops/topk.hpp
+9
-0
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
+113
-0
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp
...e/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp
+22
-0
include/ck_tile/ops/topk_softmax.hpp
include/ck_tile/ops/topk_softmax.hpp
+11
-0
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
...e/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
+166
-0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
...k_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
+123
-0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
...opk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
+63
-0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
...pk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
+46
-0
include/ck_tile/ops/welford.hpp
include/ck_tile/ops/welford.hpp
+1
-0
include/ck_tile/ops/welford/block/block_welford.hpp
include/ck_tile/ops/welford/block/block_welford.hpp
+5
-5
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp
...library/reference_tensor_operation/gpu/reference_gemm.hpp
+5
-5
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
.../tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
+105
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
...v_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
+179
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp
...n_instance/gpu/grouped_convolution_forward_dynamic_op.hpp
+278
-0
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+127
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
...ration_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
+10
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
...device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
+99
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
...ultiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
+32
-0
No files found.
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#define _BLOCK_SOFTMAX_USE_UNPACK2 0
namespace
ck_tile
{
/*
simple 2d softmax implementation, along row (dim=1)
requirement:
1). each row is within a warp
2). data type must be a dword
*/
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockSoftmax2D
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
DataType
=
typename
Problem
::
DataType
;
template
<
typename
DistributedTensor
,
index_t
dim
=
1
>
CK_TILE_DEVICE
void
operator
()(
const
DistributedTensor
&
x
,
DistributedTensor
&
y
,
number
<
dim
>
=
{})
{
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
#if _BLOCK_SOFTMAX_USE_UNPACK2
const
auto
f_max3
=
[](
auto
e0
,
auto
e1
,
auto
e2
)
{
float
rtn
;
asm
volatile
(
"v_max3_f32 %0, %1, %2, %3"
:
"=v"
(
rtn
)
:
"v"
(
e0
),
"v"
(
e1
),
"v"
(
e2
));
return
rtn
;
};
const
auto
f_sum3
=
[](
auto
e0
,
auto
e1
,
auto
e2
)
{
return
e0
+
e1
+
e2
;
};
#endif
// compute row max
auto
reduce_row_max
=
BlockReduce2D
{
x
,
-
numeric
<
DataType
>::
infinity
()};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto
row_max
=
reduce_row_max
(
f_max3
,
f_max
,
sequence
<
1
,
2
>
{});
#else
auto
row_max
=
reduce_row_max
(
f_max
);
#endif
sweep_tile
<
DistributedTensor
>
([
&
](
auto
idx
)
{
constexpr
auto
row_id
=
make_tuple
(
idx
[
number
<
0
>
{}]);
y
(
idx
)
=
exp
(
x
[
idx
]
-
row_max
[
row_id
]);
});
// compute row sum
auto
reduce_row_sum
=
BlockReduce2D
<
decltype
(
y
)
>
{
y
,
DataType
{
0
}};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto
row_sum
=
reduce_row_sum
(
f_sum3
,
f_sum
,
sequence
<
1
,
2
>
{});
#else
auto
row_sum
=
reduce_row_sum
(
f_sum
);
#endif
// reciprocal
auto
r
=
make_static_distributed_tensor
<
DataType
>
(
row_sum
.
get_tile_distribution
());
sweep_tile
(
row_sum
,
[
&
](
auto
idx
)
{
r
(
idx
)
=
DataType
{
1
}
/
row_sum
(
idx
);
});
// scale
sweep_tile
<
DistributedTensor
>
([
&
](
auto
idx
)
{
constexpr
auto
row_id
=
make_tuple
(
idx
[
number
<
0
>
{}]);
y
(
idx
)
=
y
(
idx
)
*
r
(
row_id
);
});
}
template
<
typename
DistributedTensor
,
index_t
dim
=
1
>
CK_TILE_DEVICE
decltype
(
auto
)
operator
()(
const
DistributedTensor
&
x
,
number
<
dim
>
=
{})
{
auto
y
=
DistributedTensor
{};
// distributed tensor
operator
()(
x
,
y
,
number
<
dim
>
{});
return
y
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
DataType_
>
struct
BlockSoftmax2DProblem
{
using
DataType
=
remove_cvref_t
<
DataType_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/topk.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockTopkStream2D
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
DataType
=
typename
Problem
::
DataType
;
using
IndexType
=
typename
Problem
::
IndexType
;
// TODO: if DataType is subdword, need pack into single dword to use argmax
struct
ArgmaxPacket
{
DataType
arg
;
index_t
value
;
};
template
<
typename
DistributedTensor
,
typename
OutWindow
,
typename
IdxWindow
,
index_t
dim
=
1
>
CK_TILE_DEVICE
void
operator
()(
const
DistributedTensor
&
x
,
const
OutWindow
&
out_window
,
const
IdxWindow
&
idx_window
,
index_t
k
,
number
<
dim
>
=
{})
{
OutWindow
out_window_tmp
=
out_window
;
IdxWindow
idx_window_tmp
=
idx_window
;
static_assert
(
std
::
is_same_v
<
typename
DistributedTensor
::
DataType
,
typename
OutWindow
::
DataType
>
&&
std
::
is_same_v
<
typename
DistributedTensor
::
DataType
,
DataType
>
);
static_assert
(
std
::
is_same_v
<
typename
IdxWindow
::
DataType
,
IndexType
>
);
DistributedTensor
x_tmp
=
x
;
constexpr
auto
dst_dist
=
typename
IdxWindow
::
TileDstr
{};
// argmax for topk
const
auto
f_argmax
=
[](
ArgmaxPacket
e0
,
ArgmaxPacket
e1
)
{
return
e0
.
arg
>
e1
.
arg
?
e0
:
e1
;
};
for
(
index_t
i_k
=
0
;
i_k
<
k
;
i_k
++
)
{
constexpr
auto
span_2d
=
DistributedTensor
::
get_distributed_spans
();
auto
packet
=
[
&
]()
{
auto
tmp
=
make_static_distributed_tensor
<
ArgmaxPacket
>
(
x
.
get_tile_distribution
());
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
tmp
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
ArgmaxPacket
t
;
t
.
arg
=
x_tmp
(
i_j_idx
);
// !!! we reference x here
t
.
value
=
tile_idx
.
at
(
number
<
1
>
{});
tmp
(
i_j_idx
)
=
t
;
});
});
return
tmp
;
}();
auto
argmax_init
=
ArgmaxPacket
{
-
numeric
<
DataType
>::
infinity
(),
0
};
auto
r
=
block_tile_reduce
<
ArgmaxPacket
>
(
packet
,
sequence
<
1
>
{},
f_argmax
,
argmax_init
);
block_tile_reduce_xor_sync
(
r
,
f_argmax
);
auto
o
=
make_static_distributed_tensor
<
DataType
>
(
dst_dist
);
auto
i
=
make_static_distributed_tensor
<
IndexType
>
(
dst_dist
);
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
ArgmaxPacket
tmp
=
r
(
i_j_idx
);
o
(
i_j_idx
)
=
tmp
.
arg
;
i
(
i_j_idx
)
=
tmp
.
value
;
});
});
// update value
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
x
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
auto
col_id
=
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
x_tmp
(
i_j_idx
)
=
(
col_id
==
r
(
i_j_idx
).
value
)
?
-
numeric
<
DataType
>::
infinity
()
:
x_tmp
(
i_j_idx
);
});
});
if
(
threadIdx
.
x
%
Problem
::
ColLanes
==
0
)
{
store_tile
(
out_window_tmp
,
o
);
store_tile
(
idx_window_tmp
,
i
);
}
move_tile_window
(
out_window_tmp
,
{
number
<
0
>
{},
number
<
1
>
{}});
move_tile_window
(
idx_window_tmp
,
{
number
<
0
>
{},
number
<
1
>
{}});
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template
<
typename
DataType_
,
typename
IndexType_
,
index_t
ColLanes_
>
struct
BlockTopkStream2DProblem
{
using
DataType
=
remove_cvref_t
<
DataType_
>
;
using
IndexType
=
remove_cvref_t
<
IndexType_
>
;
static
constexpr
index_t
ColLanes
=
ColLanes_
;
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
struct
TopkSoftmaxHostArgs
{
const
void
*
p_input
;
void
*
p_output
;
void
*
p_indices
;
index_t
num_rows
;
index_t
num_experts
;
index_t
topk
;
index_t
stride_input
;
// row stride for input, at least experts
index_t
stride_output
;
// row stride for output/indices, at least tpok
};
template
<
typename
Pipeline_
>
struct
TopkSoftmaxKernel
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Problem
=
remove_cvref_t
<
typename
Pipeline
::
Problem
>
;
using
InputType
=
typename
Problem
::
InputType
;
using
WeightType
=
typename
Problem
::
WeightType
;
using
IndexType
=
typename
Problem
::
IndexType
;
struct
TopkSoftmaxKargs
{
const
void
*
p_input
;
void
*
p_output
;
void
*
p_indices
;
index_t
num_rows
;
index_t
num_experts
;
index_t
topk
;
index_t
stride_input
;
// row stride for input, at least experts
index_t
stride_output
;
// row stride for output/indices, at least tpok
};
using
Kargs
=
TopkSoftmaxKargs
;
using
Hargs
=
TopkSoftmaxHostArgs
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
h
)
{
if
constexpr
(
Problem
::
LaunchType
>
0
)
{
int
num_cu
=
[
&
]()
{
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
HIP_CHECK_ERROR
(
hipGetDevice
(
&
dev
));
HIP_CHECK_ERROR
(
hipGetDeviceProperties
(
&
dev_prop
,
dev
));
return
dev_prop
.
multiProcessorCount
;
}();
return
dim3
(
num_cu
*
Problem
::
LaunchType
);
}
else
{
const
int
num_warps
=
(
h
.
num_rows
+
Problem
::
RowsPerWarp
-
1
)
/
Problem
::
RowsPerWarp
;
const
int
num_blocks
=
(
num_warps
+
Problem
::
WarpsPerBlock
-
1
)
/
Problem
::
WarpsPerBlock
;
return
dim3
(
num_blocks
);
}
}
CK_TILE_HOST
static
constexpr
auto
MakeKargs
(
const
Hargs
&
h
)
{
Kargs
k
;
k
.
p_input
=
h
.
p_input
;
k
.
p_output
=
h
.
p_output
;
k
.
p_indices
=
h
.
p_indices
;
k
.
num_rows
=
h
.
num_rows
;
k
.
num_experts
=
h
.
num_experts
;
k
.
topk
=
h
.
topk
;
k
.
stride_input
=
h
.
stride_input
;
k
.
stride_output
=
h
.
stride_output
;
return
k
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockSize
;
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
index_t
block_row_id
=
static_cast
<
index_t
>
(
blockIdx
.
x
*
Problem
::
RowsPerBlock
);
if
(
block_row_id
>
kargs
.
num_rows
)
return
;
index_t
block_os_inp
=
__builtin_amdgcn_readfirstlane
(
block_row_id
*
kargs
.
stride_input
);
index_t
block_os_out
=
__builtin_amdgcn_readfirstlane
(
block_row_id
*
kargs
.
stride_output
);
index_t
num_rows_rem
=
__builtin_amdgcn_readfirstlane
(
kargs
.
num_rows
-
block_row_id
);
const
auto
input_window
=
[
&
]()
{
const
InputType
*
p_input
=
reinterpret_cast
<
const
InputType
*>
(
kargs
.
p_input
)
+
block_os_inp
;
auto
tmp
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_input
,
make_tuple
(
num_rows_rem
,
kargs
.
num_experts
),
make_tuple
(
kargs
.
stride_input
,
1
),
number
<
Problem
::
VectorSize
>
{},
number
<
1
>
{});
auto
view
=
pad_tensor_view
(
tmp
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
Problem
::
Experts
>
{}),
sequence
<
0
,
1
>
{});
// out-most dim no need pad(leverage oob)
return
make_tile_window
(
view
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
Problem
::
Experts
>
{}),
{
0
,
0
});
}();
auto
output_window
=
[
&
]()
{
WeightType
*
p_output
=
reinterpret_cast
<
WeightType
*>
(
kargs
.
p_output
)
+
block_os_out
;
auto
tmp
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_output
,
make_tuple
(
num_rows_rem
,
kargs
.
topk
),
make_tuple
(
kargs
.
stride_output
,
1
),
number
<
Problem
::
VectorSize
>
{},
number
<
1
>
{});
auto
view
=
pad_tensor_view
(
tmp
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
1
>
{}),
sequence
<
0
,
0
>
{});
// 1. out-most dim no need pad(leverage oob)
// 2. we loop over topk 1-1, no need padding
return
make_tile_window
(
view
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
1
>
{}),
{
0
,
0
});
}();
auto
indices_window
=
[
&
]()
{
IndexType
*
p_indices
=
reinterpret_cast
<
IndexType
*>
(
kargs
.
p_indices
)
+
block_os_out
;
auto
tmp
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_indices
,
make_tuple
(
num_rows_rem
,
kargs
.
topk
),
make_tuple
(
kargs
.
stride_output
,
1
),
number
<
Problem
::
VectorSize
>
{},
number
<
1
>
{});
auto
view
=
pad_tensor_view
(
tmp
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
1
>
{}),
sequence
<
0
,
0
>
{});
// 1. out-most dim no need pad(leverage oob)
// 2. we loop over topk 1-1, no need padding
return
make_tile_window
(
view
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
1
>
{}),
{
0
,
0
});
}();
Pipeline
{}(
input_window
,
output_window
,
indices_window
,
kargs
.
num_rows
,
kargs
.
num_experts
,
kargs
.
topk
,
block_row_id
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
TopkSoftmaxWarpPerRowPolicy
>
struct
TopkSoftmaxWarpPerRowPipeline
{
// TODO: this kernel only support warp per row
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
WeightType
=
typename
Problem
::
WeightType
;
template
<
typename
InputWindow
,
typename
OutputWindow
,
typename
IndexWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
InputWindow
&
input_window
,
OutputWindow
&
out_window
,
IndexWindow
&
idx_window
,
index_t
rows
,
index_t
experts
,
index_t
k
,
index_t
block_row_id
)
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
auto
inp_win
=
make_tile_window_linear_raw
(
input_window
,
Policy
::
template
MakeInputDistribution
<
Problem
>(),
sequence
<
0
,
1
>
{});
#else
auto
inp_win
=
make_tile_window_linear
(
input_window
,
Policy
::
template
MakeInputDistribution
<
Problem
>(),
sequence
<
0
,
1
>
{});
#endif
auto
out_win
=
make_tile_window_linear
(
out_window
.
get_bottom_tensor_view
(),
out_window
.
get_window_lengths
(),
out_window
.
get_window_origin
(),
Policy
::
template
MakeOutputDistribution
<
Problem
>());
auto
idx_win
=
make_tile_window_linear
(
idx_window
.
get_bottom_tensor_view
(),
idx_window
.
get_window_lengths
(),
idx_window
.
get_window_origin
(),
Policy
::
template
MakeOutputDistribution
<
Problem
>());
auto
softmax
=
Policy
::
template
GetSoftmax
<
Problem
>();
auto
topk
=
Policy
::
template
GetTopk
<
Problem
>();
const
index_t
grid_rows_per_loop
=
gridDim
.
x
*
Problem
::
RowsPerBlock
;
while
(
1
)
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier
(
0
);
auto
x
=
load_tile_raw
(
inp_win
,
number
<-
1
>
{},
bool_constant
<
true
>
{},
bool_constant
<
true
>
{});
buffer_load_fence
(
number
<
0
>
{});
__builtin_amdgcn_sched_barrier
(
0
);
#else
auto
x
=
load_tile
(
inp_win
);
#endif
// cast and pad input data
auto
w
=
[
&
]()
{
#if 0
auto w_ = cast_tile<WeightType>(x);
constexpr auto span_2d = decltype(w_)::get_distributed_spans();
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
w_.get_tile_distribution(), i_j_idx);
const auto current_expert = x_indices.at(number<1>{});
// set to -INF if OOB so that later softmax can work properly
w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
: w_(i_j_idx);
});
});
return w_;
#else
auto
w_
=
make_static_distributed_tensor
<
WeightType
>
(
x
.
get_tile_distribution
());
auto
w_f
=
[
&
](
auto
idx
)
{
w_
(
idx
)
=
type_convert
<
WeightType
>
(
x
(
idx
));
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
w_
.
get_tile_distribution
(),
idx
);
const
auto
current_expert
=
x_indices
.
at
(
number
<
1
>
{});
w_
(
idx
)
=
current_expert
>=
experts
?
-
numeric
<
WeightType
>::
infinity
()
:
w_
(
idx
);
};
tile_sweeper
ts
{
w_
,
w_f
};
ts
();
return
w_
;
#endif
}();
// softmax
auto
y
=
softmax
(
w
);
topk
(
y
,
out_win
,
idx_win
,
k
);
// check exit
if
constexpr
(
Problem
::
LaunchType
==
0
)
{
break
;
}
else
{
block_row_id
+=
grid_rows_per_loop
;
if
(
block_row_id
>=
rows
)
break
;
}
move_tile_window
(
inp_win
,
{
grid_rows_per_loop
,
number
<
0
>
{}});
move_tile_window
(
out_win
,
{
grid_rows_per_loop
,
number
<
0
>
{}});
move_tile_window
(
idx_win
,
{
grid_rows_per_loop
,
number
<
0
>
{}});
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace
ck_tile
{
struct
TopkSoftmaxWarpPerRowPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeInputDistribution
()
{
// TODO: Y dim must have one dim that is not reduced
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Problem
::
IssuesPerCol
,
Problem
::
WarpsPerBlock
,
Problem
::
RowsPerWarpPerColIssue
>
,
sequence
<
Problem
::
IssuesPerRow
,
Problem
::
LanesPerRow
,
Problem
::
VectorSize
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOutputDistribution
()
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
Problem
::
LanesPerRow
>
,
// repeat this one
tuple
<
sequence
<
Problem
::
IssuesPerCol
,
Problem
::
WarpsPerBlock
,
Problem
::
RowsPerWarpPerColIssue
>
,
sequence
<
1
>>
,
// each row write out single element
tuple
<
sequence
<
1
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSoftmax
()
{
using
softmax_problem
=
BlockSoftmax2DProblem
<
typename
Problem
::
WeightType
>
;
return
BlockSoftmax2D
<
softmax_problem
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTopk
()
{
using
topk_problem
=
BlockTopkStream2DProblem
<
typename
Problem
::
WeightType
,
typename
Problem
::
IndexType
,
Problem
::
LanesPerRow
>
;
// Note: replicate is LanesPerRow
return
BlockTopkStream2D
<
topk_problem
>
{};
}
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
InputType_
,
typename
WeightType_
,
typename
IndexType_
,
index_t
Experts_
,
index_t
IssuesPerCol_
=
2
,
// issue along col, to make sure block_reduce() OK
index_t
BytesPerIssue_
=
sizeof
(
InputType_
),
index_t
LaunchType_
=
0
,
// 0-streaming, >0, persistent #occupancy
index_t
BlockSize_
=
256
>
struct
TopkSoftmaxWarpPerRowProblem
{
// TODO: this kernel only support warp per row
using
InputType
=
remove_cvref_t
<
InputType_
>
;
using
WeightType
=
remove_cvref_t
<
WeightType_
>
;
using
IndexType
=
remove_cvref_t
<
IndexType_
>
;
static
constexpr
index_t
LaunchType
=
LaunchType_
;
static
constexpr
index_t
Experts
=
Experts_
;
static
constexpr
index_t
BytesPerIssue
=
BytesPerIssue_
;
static
constexpr
index_t
IssuesPerCol
=
IssuesPerCol_
;
static
constexpr
index_t
BlockSize
=
BlockSize_
;
static
constexpr
index_t
WarpSize
=
get_warp_size
();
static_assert
(
BytesPerIssue
%
sizeof
(
InputType
)
==
0
);
static
constexpr
index_t
VectorSize
=
BytesPerIssue
/
sizeof
(
InputType
);
static_assert
(
Experts
%
VectorSize
==
0
);
static
constexpr
index_t
LanesPerRow
=
min
(
Experts
/
VectorSize
,
WarpSize
);
static_assert
(
WarpSize
%
LanesPerRow
==
0
);
static
constexpr
index_t
RowsPerWarpPerColIssue
=
WarpSize
/
LanesPerRow
;
static
constexpr
index_t
RowsPerWarp
=
IssuesPerCol
*
RowsPerWarpPerColIssue
;
static
constexpr
index_t
IssuesPerRow
=
Experts
/
(
LanesPerRow
*
VectorSize
);
static
constexpr
index_t
WarpsPerBlock
=
BlockSize
/
WarpSize
;
static
constexpr
index_t
RowsPerBlock
=
RowsPerWarp
*
WarpsPerBlock
;
};
}
// namespace ck_tile
include/ck_tile/ops/welford.hpp
View file @
d51701d4
...
...
@@ -6,4 +6,5 @@
#include "ck_tile/ops/welford/block/block_welford.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/welford/block/block_welford.hpp
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -276,8 +276,8 @@ struct BlockWelfordCrossWarpSync
fp32x4_t
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
all_scratch
[
i_0
*
num_warps
+
i_1
]
=
smem_ptr
[
i_0
*
num_
reduce_
warps
+
local_smem_os
+
i_1
];
all_scratch
[
i_0
*
num_
reduce_
warps
+
i_1
]
=
smem_ptr
[
i_0
*
num_warps
+
local_smem_os
+
i_1
];
});
});
block_sync_lds
();
// TODO: we don't need sync here
...
...
@@ -286,7 +286,7 @@ struct BlockWelfordCrossWarpSync
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
// TODO: use descriptor for this
auto
v_local
=
all_scratch
[
i_0
*
num_warps
];
auto
v_local
=
all_scratch
[
i_0
*
num_
reduce_
warps
];
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
auto
v_local_count
=
bit_cast
<
int
>
(
v_local
[
2
]);
...
...
@@ -294,7 +294,7 @@ struct BlockWelfordCrossWarpSync
// further reduce mean/var
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
const
fp32x4_t
v_remote
=
all_scratch
[
i_0
*
num_warps
+
i_1
];
const
fp32x4_t
v_remote
=
all_scratch
[
i_0
*
num_
reduce_
warps
+
i_1
];
const
auto
v_remote_mean
=
bit_cast
<
DataType
>
(
v_remote
[
0
]);
const
auto
v_remote_var
=
bit_cast
<
DataType
>
(
v_remote
[
1
]);
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
...
...
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp
View file @
d51701d4
...
...
@@ -45,10 +45,10 @@ __global__ void
if
(
row_idx
<
m
&&
col_idx
<
n
)
{
AccDataType
v_acc
=
static_cast
<
AccDataType
>
(
0.0
)
;
ComputeTypeA
v_a
=
static_cast
<
ComputeTypeA
>
(
0.0
)
;
ComputeTypeB
v_b
=
static_cast
<
ComputeTypeB
>
(
0.0
)
;
CDataType
v_c
=
static_cast
<
CDataType
>
(
0.0
)
;
AccDataType
v_acc
{
0
}
;
ComputeTypeA
v_a
{
0
}
;
ComputeTypeB
v_b
{
0
}
;
CDataType
v_c
{
0
}
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
...
...
@@ -76,7 +76,7 @@ __global__ void
// apply b_element_op
b_element_op
(
v_b
,
p_b_grid
[
element_idx_b
]);
// multiply and accumulate
v_acc
+=
static_cas
t
<
AccDataType
>
(
v_a
)
*
static_cas
t
<
AccDataType
>
(
v_b
);
v_acc
+=
type_conver
t
<
AccDataType
>
(
v_a
)
*
type_conver
t
<
AccDataType
>
(
v_b
);
}
// apply c_element_op
c_element_op
(
v_c
,
v_acc
);
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
View file @
d51701d4
...
...
@@ -96,6 +96,87 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
MultiplyMultiply
>>>&
instances
);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
#endif
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
...
...
@@ -155,6 +236,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs
);
}
}
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
BF16
=
ck
::
bhalf_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
namespace
ck
::
tensor_layout
::
convolution
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DynamicUnaryOp
=
ck
::
tensor_operation
::
element_wise
::
DynamicUnaryOp
;
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
static
constexpr
auto
ConvFwd1x1P0
=
ConvolutionForwardSpecialization
::
Filter1x1Pad0
;
static
constexpr
auto
ConvFwd1x1S1P0
=
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
;
static
constexpr
auto
ConvFwdOddC
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
;
static
constexpr
auto
GemmMNKPadding
=
GemmSpecialization
::
MNKPadding
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// clang-format on
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_dynamic_op_f16_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// clang-format on
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_dynamic_op_f32_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
16
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_dynamic_op_int8_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// clang-format on
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dynamic.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DynamicUnaryOp
=
ck
::
tensor_operation
::
element_wise
::
DynamicUnaryOp
;
#ifdef CK_ENABLE_BF16
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
ck
::
Tuple
<>
,
NHWGK
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
ck
::
Tuple
<>
,
NHWGK
,
F16
,
F16
,
ck
::
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
ck
::
Tuple
<>
,
NHWGK
,
F32
,
F32
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
ck
::
Tuple
<>
,
NHWGK
,
int8_t
,
int8_t
,
ck
::
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F16
,
F16
,
ck
::
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F32
,
F32
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
int8_t
,
int8_t
,
ck
::
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
);
#endif
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
DLayouts
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
DDataTypes
,
typename
OutDataType
,
typename
ComputeType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
DynamicUnaryOp
,
ComputeType
>>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
DynamicUnaryOp
,
ComputeType
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWGK
>
&&
DLayouts
::
Size
()
==
0
)
{
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
float
>
)
{
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
&&
is_same_v
<
ComputeType
,
half_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instances
(
op_ptrs
);
}
#endif
}
else
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWGC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NHWGK
>
&&
DLayouts
::
Size
()
==
0
)
{
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
float
>
)
{
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f32_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
&&
is_same_v
<
ComputeType
,
half_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instances
(
op_ptrs
);
}
#endif
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/utility/check_err.hpp
View file @
d51701d4
...
...
@@ -23,6 +23,130 @@
namespace
ck
{
namespace
utils
{
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_relative_threshold
(
const
int
numberOfAccumulations
=
1
)
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
is_same_v
<
ComputeDataType
,
F8
>
||
is_same_v
<
ComputeDataType
,
F16
>
||
is_same_v
<
ComputeDataType
,
BF16
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
double
compute_error
=
0
;
if
constexpr
(
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
)
{
return
0
;
}
else
{
compute_error
=
std
::
pow
(
2
,
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
is_same_v
<
OutDataType
,
F8
>
||
is_same_v
<
OutDataType
,
F16
>
||
is_same_v
<
OutDataType
,
BF16
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
double
output_error
=
0
;
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
)
{
return
0
;
}
else
{
output_error
=
std
::
pow
(
2
,
-
NumericUtils
<
OutDataType
>::
mant
)
*
0.5
;
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_same_v
<
AccDataType
,
F8
>
||
is_same_v
<
AccDataType
,
F16
>
||
is_same_v
<
AccDataType
,
BF16
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
)
{
return
0
;
}
else
{
acc_error
=
std
::
pow
(
2
,
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
numberOfAccumulations
;
}
return
std
::
max
(
acc_error
,
midway_error
);
}
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
numberOfAccumulations
=
1
)
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
is_same_v
<
ComputeDataType
,
F8
>
||
is_same_v
<
ComputeDataType
,
F16
>
||
is_same_v
<
ComputeDataType
,
BF16
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
double
compute_error
=
0
;
if
constexpr
(
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
)
{
return
0
;
}
else
{
compute_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
is_same_v
<
OutDataType
,
F8
>
||
is_same_v
<
OutDataType
,
F16
>
||
is_same_v
<
OutDataType
,
BF16
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
double
output_error
=
0
;
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
)
{
return
0
;
}
else
{
output_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
OutDataType
>::
mant
)
*
0.5
;
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_same_v
<
AccDataType
,
F8
>
||
is_same_v
<
AccDataType
,
F16
>
||
is_same_v
<
AccDataType
,
BF16
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
)
{
return
0
;
}
else
{
acc_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
numberOfAccumulations
;
}
return
std
::
max
(
acc_error
,
midway_error
);
}
template
<
typename
Range
,
typename
RefRange
>
typename
std
::
enable_if
<
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
...
...
@@ -253,11 +377,13 @@ check_err(const Range& out,
int
err_count
=
0
;
double
err
=
0
;
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
err
=
std
::
abs
(
o
-
r
);
if
(
err
>
atol
+
rtol
*
std
::
abs
(
r
)
||
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
))
{
max_err
=
err
>
max_err
?
err
:
max_err
;
...
...
@@ -270,6 +396,7 @@ check_err(const Range& out,
res
=
false
;
}
}
if
(
!
res
)
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
View file @
d51701d4
...
...
@@ -8,9 +8,19 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
add_instance_library
(
device_gemm_multiply_multiply_instance
${
GEMM_MULTIPLY_MULTIPLY_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
I8
=
int8_t
;
using
I32
=
int
;
using
BF16
=
bhalf_t
;
using
F32
=
float
;
using
Row
=
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
index_t
...
Is
>
using
S
=
Sequence
<
Is
...
>
;
using
PassThrough
=
element_wise
::
PassThrough
;
using
MultiplyMultiply
=
element_wise
::
MultiplyMultiply
;
static
constexpr
auto
GemmDefault
=
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmKPadding
=
GemmSpecialization
::
KPadding
;
static
constexpr
auto
GemmMNPadding
=
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
GemmMNKPadding
=
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
Intrawave
=
BlockGemmPipelineScheduler
::
Intrawave
;
static
constexpr
auto
Interwave
=
BlockGemmPipelineScheduler
::
Interwave
;
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
128
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
224
,
256
,
128
,
16
,
16
,
16
,
16
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
224
,
128
,
16
,
16
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
256
,
64
,
16
,
16
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
128
,
64
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
// clang-format oI
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
16
,
32
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
// Memory friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
32
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
16
,
128
,
16
,
16
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
128
,
32
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
128
,
16
,
128
,
16
,
16
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
64
,
32
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
64
,
16
,
128
,
16
,
16
,
16
,
16
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
64
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
16
,
32
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
16
,
64
,
128
,
16
,
16
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
16
,
128
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
// clang-format oI
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
0 → 100644
View file @
d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
<
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
…
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment