gaoqiong / composable_kernel_ROCM · Commits

Commit e7cde218, authored Jul 03, 2024 by Harisankar Sadasivan

Changes suggested in PR review are made: removing comments and correcting the copyright.

Parent: 57a38a1d

Showing 20 changed files with 96 additions and 780 deletions (+96 / -780)
.pre-commit-config.yaml (+0 / -0)
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp (+1 / -1)
example/01_gemm/run_gemm_example_streamk_v2.inc (+2 / -14)
include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp (+1 / -1)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp (+7 / -21)
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+28 / -118)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp (+44 / -97)
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp (+1 / -452)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt (+1 / -65)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp (+1 / -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp (+1 / -1)
.pre-commit-config.yaml (100644 → 100755)

File mode changed from 100644 to 100755
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
...
...
example/01_gemm/run_gemm_example_streamk_v2.inc

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
...
@@ -117,7 +117,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     auto f_get_default_stride =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            if(stride == 0)
+            if(stride == -1)
             {
                 // give a chance if stride is zero, return a default packed stride
                 if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
...
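The sentinel change above (`stride == 0` becomes `stride == -1`) drives the default-stride selection. A standalone sketch of the same convention follows; `RowMajor`/`ColMajor` are stand-in tags for `ck::tensor_layout::gemm::*`, `get_default_stride` is a hypothetical name, and a signed stride type is used so the `-1` sentinel compares cleanly (the original lambda takes `std::size_t`, where `-1` wraps to the maximum value):

```cpp
#include <cstddef>
#include <type_traits>

struct RowMajor {};
struct ColMajor {};

// -1 requests the default packed stride: for row-major the distance between
// rows is the number of columns; for column-major the distance between
// columns is the number of rows. Any other value is passed through.
template <typename Layout>
std::size_t get_default_stride(std::size_t row, std::size_t col, long long stride, Layout)
{
    if(stride == -1)
    {
        if constexpr(std::is_same_v<Layout, RowMajor>)
            return col; // packed row-major
        else
            return row; // packed column-major
    }
    return static_cast<std::size_t>(stride);
}
```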
@@ -162,18 +162,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
     }
-#if 0
-    printf("B matrix:\n");
-    for(int in = 0; in < N; in++)
-    {
-        for(int ik = 0; ik < K; ik++)
-        {
-            printf("%02x ", *(reinterpret_cast<uint8_t*>(&b_k_n(ik, in))));
-            if(ik % 8 == 7)
-                printf("|");
-        }
-        printf("\n");
-    }
-#endif
     Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
     Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
...
...
include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
...
@@ -147,10 +147,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
             index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
             const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
-            hipGetErrorString(hipMemsetAsync(
-                arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_)); // HS
+            hipGetErrorString(hipMemsetAsync(
+                arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
             const auto Run = [&](const auto& kernel) {
                 dim3 grid_dim;
                 if(arg.Grid_size < 0)
...
...
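The duplicated `hipMemsetAsync` above only drops a trailing `// HS` marker, but the memset itself is load-bearing: with an atomic stream-k reduction, several workgroups add their partial K-slices into the same C tile, so C must start at zero. A host-side analogue of that accumulation (`accumulate_partial` is a hypothetical helper, not library code):

```cpp
#include <cstddef>
#include <vector>

// Each stream-k workgroup computes a partial GEMM over a slice of K and adds
// it into C; correctness requires C to be zero-initialized first (hence the
// hipMemsetAsync before launch). On the GPU the += is an atomic add.
void accumulate_partial(std::vector<float>& c, const std::vector<float>& partial)
{
    for(std::size_t i = 0; i < c.size(); ++i)
        c[i] += partial[i];
}
```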
@@ -193,25 +191,13 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                     };
                     ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
-                        stream_config,
-                        run_flush_cache,
-                        kernel,
-                        grid_dim,
-                        // dim3(gdx, gdy, gdz),
-                        dim3(BlockSize),
-                        0,
-                        arg_);
+                        stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_);
                 }
                 else
                 {
-                    ave_time = launch_and_time_kernel(
-                        stream_config, kernel, // dim3(gdx, gdy, gdz),
-                        grid_dim, dim3(BlockSize), 0, arg);
+                    ave_time = launch_and_time_kernel(
+                        stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
                 }
             };
...
...
@@ -477,7 +463,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                           BElementwiseOperation,
                           CElementwiseOperation)
    {
-        // return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
        return Argument{
-            p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS
+            p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
    }
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
...
...
@@ -1461,31 +1461,27 @@ struct BlockToCTileMap_GemmStreamK_v2
        // check if there's enough work for DP + stream-k
        bool bigEnough = num_tiles > grid_size;
-        // select between 1 tile and 2 tile sk
+        // select between stream-k strategies
        uint32_t sk_tiles = 0;
-        if(streamk_sel == 1)
+        if(streamk_sel == 1) // 1 tile stream-k
        {
            sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
        }
-        else if(streamk_sel == 2)
+        else if(streamk_sel == 2) // 2-tile stream-k
        {
            sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
        }
-        else if(streamk_sel == 3)
+        else if(streamk_sel == 3) // 3-tile stream-k
        {
            sk_tiles =
                (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size) : num_tiles;
        }
-        else if(streamk_sel == 4)
+        else if(streamk_sel == 4) // 4-tile stream-k
        {
            sk_tiles =
                (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size) : num_tiles;
        }
        sk_num_blocks = sk_tiles;
-        // if(sk_tiles < sk_num_blocks)
-        // {
-        //     sk_num_blocks = sk_tiles;
-        // }
        // remaining tiles are DP tiles
        dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
...
...
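The selection above condenses into a small host-side sketch (names mirror `BlockToCTileMap_GemmStreamK_v2`, but this is an illustration, not the library code): `streamk_sel = s` assigns roughly `s - 1` full waves of tiles plus the partial remainder wave to stream-k blocks, and everything else to data-parallel (DP) blocks; when the problem has fewer tiles than the threshold, all tiles go to stream-k.

```cpp
#include <cstdint>

struct SkSplit
{
    uint32_t sk_tiles;
    uint32_t dp_tiles;
};

// Partition num_tiles output tiles between stream-k and data-parallel blocks,
// following the streamk_sel == 1..4 branches in the diff above.
inline SkSplit select_sk_tiles(uint32_t num_tiles, uint32_t grid_size, int streamk_sel)
{
    const bool bigEnough = num_tiles > grid_size;
    uint32_t sk_tiles    = 0;
    if(streamk_sel == 1) // remainder wave only
        sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
    else if(streamk_sel == 2) // one full wave + remainder
        sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
    else if(streamk_sel == 3) // two full waves + remainder
        sk_tiles =
            (num_tiles > 2 * grid_size) ? (2 * grid_size + num_tiles % grid_size) : num_tiles;
    else if(streamk_sel == 4) // three full waves + remainder
        sk_tiles =
            (num_tiles > 3 * grid_size) ? (3 * grid_size + num_tiles % grid_size) : num_tiles;
    // remaining tiles are DP tiles
    const uint32_t dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
    return {sk_tiles, dp_tiles};
}
```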
@@ -1508,7 +1504,6 @@ struct BlockToCTileMap_GemmStreamK_v2
            dp_num_blocks      = dp_tiles;
            dp_start_block_idx = sk_num_blocks;
-            // dp_start_block_idx = ((sk_num_blocks + grid_size - 1) / grid_size) * grid_size;
        }
        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
...
...
@@ -1523,30 +1518,29 @@ struct BlockToCTileMap_GemmStreamK_v2
            equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
        }
-#if 0
-        printf("streamk_sel=%0d,grid_size=%0d, num_tiles:%d, dp_tiles:%d, sk_tiles:%u, "
-               "sk_num_blocks:%d,dp_num_blocks:%d,sk_num_big_blocks:%d, "
-               "sk_total_iters:%d, dp_start_block_idx:%d, "
-               "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
-               " workspace(acc float):%u\n",
-               streamk_sel,
-               grid_size,
-               // occupancy,
-               // get_grid_dims(num_cu, occupancy).x,
-               num_tiles,
-               dp_tiles,
-               get_sk_tiles(),
-               sk_num_blocks,
-               dp_num_blocks,
-               sk_num_big_blocks,
-               sk_total_iters,
-               dp_start_block_idx,
-               k_iters_per_tile.get(),
-               k_iters_per_big_block,
-               reduction_start_block_idx,
-               get_workspace_size(sizeof(float)));
-#endif
+        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+        {
+            printf("streamk_sel=%0d,grid_size=%0d, num_tiles:%d, dp_tiles:%d, sk_tiles:%u, "
+                   "sk_num_blocks:%d,dp_num_blocks:%d,sk_num_big_blocks:%d, "
+                   "sk_total_iters:%d, dp_start_block_idx:%d, "
+                   "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
+                   " workspace(acc float):%u\n",
+                   streamk_sel,
+                   grid_size,
+                   num_tiles,
+                   dp_tiles,
+                   get_sk_tiles(),
+                   sk_num_blocks,
+                   dp_num_blocks,
+                   sk_num_big_blocks,
+                   sk_total_iters,
+                   dp_start_block_idx,
+                   k_iters_per_tile.get(),
+                   k_iters_per_big_block,
+                   reduction_start_block_idx,
+                   get_workspace_size(sizeof(float)));
+        }
    }
    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
...
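The change above swaps a compiled-out `#if 0` printf for a runtime check, so the stream-k partitioning dump can be turned on via the `CK_LOGGING` environment variable without rebuilding. A minimal sketch of that gating pattern; `env_is_enabled` is a hypothetical stand-in for the library's `ck::EnvIsEnabled`/`CK_ENV` helpers:

```cpp
#include <cstdio>
#include <cstdlib>

// Treat an unset, empty, or "0" variable as disabled; anything else enables
// the verbose output at runtime, with no recompilation needed.
inline bool env_is_enabled(const char* name)
{
    const char* v = std::getenv(name);
    return v != nullptr && v[0] != '\0' && v[0] != '0';
}

inline void log_partitioning(int streamk_sel, int grid_size)
{
    if(env_is_enabled("CK_LOGGING"))
        std::printf("streamk_sel=%d, grid_size=%d\n", streamk_sel, grid_size);
}
```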
@@ -1656,90 +1650,6 @@ struct BlockToCTileMap_GemmStreamK_v2
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
// adding gfx94x optimized
// index_t block_1d_id = tile_idx;
// const index_t N0 = n_tiles_value;
// const index_t M0 = math::integer_divide_ceil(n * m / m, MPerBlock);
// // index_t GroupNum = 8;
// // index_t M01_ = 4;
// if(M0 == 1)
// {
// return make_tuple(0, block_1d_id);
// }
// else if(N0 == 1)
// {
// return make_tuple(block_1d_id, 0);
// }
// // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
// else
// {
// const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
// const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
// auto group_id_x = block_1d_id % GroupNum;
// auto group_id_y = block_1d_id / GroupNum;
// auto remap_block_1d_id =
// group_id_x <= big_group_num
// ? group_id_x * group_size + group_id_y
// : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
// index_t idx_N0 = remap_block_1d_id % N0;
// index_t idx_M0 = remap_block_1d_id / N0;
// const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
// index_t idx_M00 = idx_M0 / M01_;
// index_t idx_M01 = idx_M0 % M01_;
// index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
// /**
// * idxN0
// *
// * |< mtx N >|
// *
// * NPerBlock NPerBlock NPerBlock NPerBlock
// * N_0 N_1 N_2 N_3
// * - |-----------|-----------|-----------|-----|-----|-
// * ^ | - - 0 |/----> 2 | | | |
// * | | | / | | | | | M_0 MPerBlock
// * | M | /| | | | | |
// * |-0---|---/-|-----|-----|-----------|-----|-----|-
// * | 1 | / | | | blockid | | |
// * idxM0 | | | / | V | 5 | | | M_1 MPerBlock
// * | - V 1 | - 3 | | | |
// * |-----------|-----------|-----------|-----|-----|-
// * mtx M | | | | | |
// * | | | | | | M_2 MPerBlock
// * | | | | | |
// * |-----------|-----------|-----------|-----|-----|-
// * | | | | | |
// * | | | | | | M_3 MPerBlock
// * | | | | | |
// * |-----------|-----------|-----------|-----|-----|-
// * V | | | | | |
// * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
// * | | | | | |
// * |-----------|-----------|-----------|-----|-----|-
// * Example:
// * assume:
// * M0 = 5
// * N0 = 4
// * block_1d_id = 5
// * M01 = 2
// *
// * idx_N0 = 1
// * idx_M0 = 1
// * M01_adapt = 2
// * idx_M00 = 0
// * idx_M01 = 1
// * idx_N0_M01_local = 5
// * output {1, 2}
// */
// return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
// idx_N0_M01_local / M01_adapt);
//}
    }
    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
...
...
@@ -32,22 +32,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-        // __attribute__((amdgpu_waves_per_eu(1, 1)))
-        kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
+        kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        // karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
-        // karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
-        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
+        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
#else
    ignore = karg;
#endif // end of if (defined(__gfx9__))
...
...
@@ -62,8 +53,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-        // __attribute__((amdgpu_waves_per_eu(1, 1)))
-        kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
+        kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // Pass two lds pointer is the key to tell compiler that ds_read/write
...
...
@@ -71,17 +61,8 @@ __global__ void
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        // karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
-        // karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
-        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
+        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
#else
    ignore = karg;
#endif // end of if (defined(__gfx9__))
...
...
@@ -155,15 +136,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
-    // __host__ static auto CalculateGridSize(index_t M, index_t N) //, index_t KBatch)
-    // {
-    //     // return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
-    //     // return ((Block2CTileMap::CalculateGridSize(M, N)) * KBatch);
-    //     // return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
-    //     return Block2CTileMap::CalculateGridSize(M, N);
-    // }
    __host__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_least_multiple(M, MPerBlock);
...
...
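`CalculateMPadded` above rounds M up to the next multiple of MPerBlock with `math::integer_least_multiple`; the same padding appears for N and K elsewhere in the file. An equivalent stand-alone version of that rounding, for illustration:

```cpp
// Round x up to the least multiple of `multiple` that is >= x,
// i.e. ceil(x / multiple) * multiple, using integer arithmetic only.
inline int integer_least_multiple(int x, int multiple)
{
    return (x + multiple - 1) / multiple * multiple;
}
```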
@@ -995,10 +967,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        }
        else
        {
-            // constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
-            // auto K_t = KReadVec;
-            // auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
-            if(karg.K <= 0) // HS
+            if(karg.K <= 0)
            {
                return false;
            }
...
...
@@ -1103,10 +1073,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
        }
-        // if(karg.KBatch > 1)
-        // {
-        //     return false;
-        // }
    }
    // check gridwise gemm pipeline
...
...
@@ -1152,16 +1118,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    // return block_id to C matrix tile idx (m0, n0) mapping
-    // if arch = gfx942
-    // using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
-    // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
    using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
                                                                  NPerBlock,
                                                                  KPerBlock,
                                                                  StreamKReductionStrategy::Atomic,
                                                                  8,
-                                                                 4>; // HS
+                                                                 4>;
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
...
...
@@ -1177,43 +1139,39 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};
        // Provide a value for TileSwizzleSubM_
        Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
                                                         problem.N,
                                                         AK0Number * problem.KPadded,
                                                         problem.Grid_size,
-                                                        problem.Streamk_sel); // HS
-        uint32_t iter_start, iter_end; // HS
-        bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; // HS
-        index_t num_k_block_main_loop; // HS
+                                                        problem.Streamk_sel);
+        uint32_t iter_start, iter_end;
+        bool is_sk_block, is_dp_block;
+        index_t num_k_block_main_loop;
        for(auto block_idx = get_block_1d_id();
            block_idx < block_2_ctile_map_streamk.get_grid_dims();
            block_idx += gridDim.x)
        {
-            // for(unsigned int kbatch_id = 0; kbatch_id < static_cast<unsigned
-            // int>(problem.KBatch);
-            // kbatch_id++)
            is_sk_block =
                static_cast<uint32_t>(block_idx) < block_2_ctile_map_streamk.sk_num_blocks;
            is_dp_block =
                static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
                static_cast<uint32_t>(block_idx) <
-                    block_2_ctile_map_streamk.reduction_start_block_idx; // HS
-            block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); // HS
-            num_k_block_main_loop = iter_end - iter_start; // HS
+                    block_2_ctile_map_streamk.reduction_start_block_idx;
+            block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
+            num_k_block_main_loop = iter_end - iter_start;
            while(true)
            {
                uint32_t current_iter_length =
                    __builtin_amdgcn_readfirstlane(block_2_ctile_map_streamk.get_current_iter_length(
-                        iter_start, iter_end, num_k_block_main_loop)); // HS
-                uint32_t tile_idx, iter_offset; // HS
+                        iter_start, iter_end, num_k_block_main_loop));
+                uint32_t tile_idx, iter_offset;
                block_2_ctile_map_streamk.get_tile_idx_with_offset(
-                    iter_end - 1, tile_idx, iter_offset); // HS
+                    iter_end - 1, tile_idx, iter_offset);
                iter_offset =
-                    __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); // HS
+                    __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
                const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
                                                                                 problem.MPadded,
...
...
@@ -1237,17 +1195,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid /*+ splitk_batch_offset.a_k_split_offset*/,
-            a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
+            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid /*+ splitk_batch_offset.b_k_split_offset*/,
-            b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        // const auto block_work_idx =
        // block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx));
        auto block_work_idx =
-            block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); // HS
+            block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
...
...
@@ -1260,7 +1214,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
        const index_t k0_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); // HS
+            __builtin_amdgcn_readfirstlane(iter_offset * AK0Number);
        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
...
...
@@ -1298,7 +1252,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                true,
                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
-                make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), // HS
+                make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
...
...
@@ -1361,7 +1315,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
                    (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
-                    KPerBlock); // HS :AK0*KPadded/KPerBlock
+                    KPerBlock); // :AK0*KPadded/KPerBlock
                blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                    a_grid_desc_ak0_m_ak1,
...
...
@@ -1607,7 +1561,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                if(iter_end <= iter_start)
                    break;
                // make sure next loop LDS is ready for use
-                block_sync_lds(); // HS
+                block_sync_lds();
            }
        }
    }
...
...
@@ -1627,13 +1581,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};
        Block2CTileMap_streamk block_2_ctile_map_streamk(
-            problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size); // HS
-        uint32_t iter_start, iter_end; // HS
-        bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; // HS
-        index_t num_k_block_main_loop; // HS
+            problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size);
+        uint32_t iter_start, iter_end;
+        bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block;
+        index_t num_k_block_main_loop;
        for(auto block_idx = get_block_1d_id();
            block_idx < block_2_ctile_map_streamk.get_grid_dims();
...
...
@@ -1644,21 +1596,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
            is_dp_block =
                static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
                static_cast<uint32_t>(block_idx) <
-                    block_2_ctile_map_streamk.reduction_start_block_idx; // HS
-            block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); // HS
-            num_k_block_main_loop = iter_end - iter_start; // HS
+                    block_2_ctile_map_streamk.reduction_start_block_idx;
+            block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
+            num_k_block_main_loop = iter_end - iter_start;
            {
                uint32_t current_iter_length =
                    __builtin_amdgcn_readfirstlane(block_2_ctile_map_streamk.get_current_iter_length(
-                        iter_start, iter_end, num_k_block_main_loop)); // HS
-                uint32_t tile_idx, iter_offset; // HS
+                        iter_start, iter_end, num_k_block_main_loop));
+                uint32_t tile_idx, iter_offset;
                block_2_ctile_map_streamk.get_tile_idx_with_offset(
-                    iter_end - 1, tile_idx, iter_offset); // HS
+                    iter_end - 1, tile_idx, iter_offset);
                iter_offset =
-                    __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); // HS
+                    __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
                const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
                                                                                 problem.MPadded,
...
...
@@ -1683,16 +1634,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid /*+ splitk_batch_offset.a_k_split_offset*/,
-            a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
+            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid /*+ splitk_batch_offset.b_k_split_offset*/,
-            b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        // const auto block_work_idx =
        // block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx));
        auto block_work_idx =
-            block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); // HS
+            block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
...
...
@@ -1704,7 +1651,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
        const index_t k0_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); // HS
+            __builtin_amdgcn_readfirstlane(iter_offset * AK0Number);
        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
...
...
@@ -1742,7 +1689,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                true,
                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
-                make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), // HS
+                make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
...
...
@@ -1773,7 +1720,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                true,
                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
-                make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), // HS
+                make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
...
@@ -237,306 +237,6 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpaddin
                              PassThrough,
                              PassThrough>>>& instances);
#endif
// #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// #endif
// #ifdef CK_ENABLE_FP16
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// #endif
template <typename ADataType, typename BDataType, typename CDataType,
...
...
@@ -626,158 +326,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
}
}
#endif
// #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
// if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
// is_same_v<CDataType, half_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// }
// else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
// is_same_v<CDataType, half_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// }
// #endif
// #ifdef CK_ENABLE_FP16
// if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
// is_same_v<CDataType, bhalf_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// }
// #endif
        return op_ptrs;
}
};
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt
View file @
e7cde218
...
...
@@ -21,70 +21,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
)
add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES})
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
...
...