Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8ce41034
Commit
8ce41034
authored
Feb 08, 2025
by
aska-0096
Browse files
Merge branch 'develop' of
https://github.com/ROCm/composable_kernel
into update_cka8w8_uc
parents
730b98e1
a8c5bd9b
Changes
32
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
552 additions
and
108 deletions
+552
-108
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+13
-7
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+3
-2
include/ck_tile/ops/batched_transpose.hpp
include/ck_tile/ops/batched_transpose.hpp
+1
-1
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+2
-1
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
.../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
+4
-1
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+5
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+68
-9
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+236
-55
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+1
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp
...rary/reference_tensor_operation/cpu/reference_mx_gemm.hpp
+178
-0
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
...16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
+27
-9
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+14
-16
No files found.
include/ck_tile/host/check_err.hpp
View file @
8ce41034
...
@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
...
@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double
get_relative_threshold
(
const
int
number_of_accumulations
=
1
)
double
get_relative_threshold
(
const
int
number_of_accumulations
=
1
)
{
{
using
F8
=
ck_tile
::
fp8_t
;
using
F8
=
ck_tile
::
fp8_t
;
using
BF8
=
ck_tile
::
bf8_t
;
using
F16
=
ck_tile
::
half_t
;
using
F16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
F32
=
float
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
using
I32
=
int32_t
;
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
double
compute_error
=
0
;
double
compute_error
=
0
;
...
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error
=
std
::
pow
(
2
,
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
compute_error
=
std
::
pow
(
2
,
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
}
static_assert
(
is_any_of
<
OutDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
OutDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
double
output_error
=
0
;
double
output_error
=
0
;
...
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_any_of
<
AccDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
AccDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
double
acc_error
=
0
;
double
acc_error
=
0
;
...
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
...
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number_of_accumulations
=
1
)
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number_of_accumulations
=
1
)
{
{
using
F8
=
ck_tile
::
fp8_t
;
using
F8
=
ck_tile
::
fp8_t
;
using
BF8
=
ck_tile
::
bf8_t
;
using
F16
=
ck_tile
::
half_t
;
using
F16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
F32
=
float
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
using
I32
=
int32_t
;
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
...
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
...
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
compute_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
}
static_assert
(
is_any_of
<
OutDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
OutDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
double
output_error
=
0
;
double
output_error
=
0
;
...
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
...
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_any_of
<
AccDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
AccDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
double
acc_error
=
0
;
double
acc_error
=
0
;
...
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
}
if
(
!
res
)
if
(
!
res
)
{
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
const
float
error_percent
=
static_cast
<
float
>
(
err_count
)
/
static_cast
<
float
>
(
out
.
size
())
*
100.
f
;
std
::
cerr
<<
"max err: "
<<
max_err
;
std
::
cerr
<<
", number of errors: "
<<
err_count
;
std
::
cerr
<<
", "
<<
error_percent
<<
"% wrong values"
<<
std
::
endl
;
}
}
return
res
;
return
res
;
}
}
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
8ce41034
...
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
...
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
col
*
strideB
+
k
?
col
*
strideB
+
k
:
k
*
strideB
+
col
;
:
k
*
strideB
+
col
;
acc
+=
static_cast
<
AccDataType
>
(
A
[
a_index
])
*
static_cast
<
AccDataType
>
(
B
[
b_index
]);
acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
A
[
a_index
])
*
ck_tile
::
type_convert
<
AccDataType
>
(
B
[
b_index
]);
}
}
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideC
+
col
?
row
*
strideC
+
col
:
col
*
strideC
+
row
;
:
col
*
strideC
+
row
;
C
[
c_index
]
=
acc
;
C
[
c_index
]
=
ck_tile
::
type_convert
<
CDataType
>
(
acc
)
;
}
}
}
}
...
...
include/ck_tile/ops/batched_transpose.hpp
View file @
8ce41034
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
View file @
8ce41034
...
@@ -77,6 +77,7 @@ struct CShuffleEpilogue
...
@@ -77,6 +77,7 @@ struct CShuffleEpilogue
*
*
* @return The vector store size for C tensor.
* @return The vector store size for C tensor.
*/
*/
template
<
typename
ODataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
{
constexpr
index_t
MaxVectorStoreSize
=
16
;
constexpr
index_t
MaxVectorStoreSize
=
16
;
...
@@ -142,7 +143,7 @@ struct CShuffleEpilogue
...
@@ -142,7 +143,7 @@ struct CShuffleEpilogue
TileDistributionEncodingPattern2D
<
kBlockSize
,
TileDistributionEncodingPattern2D
<
kBlockSize
,
kMPerIteration
,
kMPerIteration
,
kNPerIteration
,
kNPerIteration
,
GetVectorSizeC
(),
GetVectorSizeC
<
ODataType
>
(),
tile_distribution_pattern
::
thread_raked
>
;
tile_distribution_pattern
::
thread_raked
>
;
constexpr
auto
dram_tile_distribution
=
TileEncodingPattern
::
Make2DStaticTileDistribution
();
constexpr
auto
dram_tile_distribution
=
TileEncodingPattern
::
Make2DStaticTileDistribution
();
...
...
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
View file @
8ce41034
...
@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
...
@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
// TODO: Should we have two policies? Interwave & Intrawave ??
// TODO: Should we have two policies? Interwave & Intrawave ??
static
constexpr
index_t
InterWaveSchedulingMacClusters
=
1
;
static
constexpr
index_t
InterWaveSchedulingMacClusters
=
1
;
static
constexpr
index_t
KPack
=
WarpGemm
::
kKPerThread
;
// should be at least equal to: WarpGemm::Impl::kABKPerLane
// and the question is how to assess upper limit or exact value?
// TODO: Should we introduce AK1/BK1 parameters ?
static
constexpr
index_t
KPack
=
8
;
static
constexpr
index_t
KPerThread
=
KIterPerWarp
*
KPack
;
static
constexpr
index_t
KPerThread
=
KIterPerWarp
*
KPack
;
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPack
;
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPack
;
};
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
8ce41034
...
@@ -159,7 +159,7 @@ struct GemmKernel
...
@@ -159,7 +159,7 @@ struct GemmKernel
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
{
{
if
constexpr
(
EpiloguePipeline
::
GetVectorSizeC
()
%
2
!=
0
&&
if
constexpr
(
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
%
2
!=
0
&&
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
)
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
)
{
{
if
(
kargs
.
k_batch
!=
1
)
if
(
kargs
.
k_batch
!=
1
)
...
@@ -240,7 +240,7 @@ struct GemmKernel
...
@@ -240,7 +240,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
N
%
EpiloguePipeline
::
GetVectorSizeC
()
!=
0
)
if
(
kargs
.
N
%
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
!=
0
)
{
{
std
::
cerr
<<
"N is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"N is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -255,7 +255,7 @@ struct GemmKernel
...
@@ -255,7 +255,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
M
%
EpiloguePipeline
::
GetVectorSizeC
()
!=
0
)
if
(
kargs
.
M
%
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
!=
0
)
{
{
std
::
cerr
<<
"M is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"M is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -321,7 +321,7 @@ struct GemmKernel
...
@@ -321,7 +321,7 @@ struct GemmKernel
c_ptr
,
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
EpiloguePipeline
::
GetVectorSizeC
()
>
{},
number
<
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
else
else
...
@@ -519,7 +519,7 @@ struct GemmKernel
...
@@ -519,7 +519,7 @@ struct GemmKernel
{
{
// Do not compile in case where we have unsupported
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
// VectorSizeC & data type configuration.
if
constexpr
(
!
(
EpiloguePipeline
::
GetVectorSizeC
()
%
2
!=
0
&&
if
constexpr
(
!
(
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
%
2
!=
0
&&
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
))
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
))
{
{
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
8ce41034
...
@@ -3,6 +3,9 @@
...
@@ -3,6 +3,9 @@
#pragma once
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
...
@@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
CK_TILE_HOST
static
std
::
string
Print
()
{
constexpr
index_t
MPerXDL
=
BlockGemm
::
WarpGemm
::
kM
;
constexpr
index_t
NPerXDL
=
BlockGemm
::
WarpGemm
::
kN
;
constexpr
index_t
KPerXDL
=
BlockGemm
::
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kK
;
constexpr
index_t
WaveSize
=
64
;
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
I0
{});
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
I1
{});
// Below should be equal to AK1|BK1
constexpr
index_t
A_LDS_Read_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Read_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_LDS_Write_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Write_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeA
());
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeB
());
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Write_Width
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Write_Width
);
constexpr
index_t
A_LDS_Read_Inst_Num
=
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Read_Width
);
constexpr
index_t
B_LDS_Read_Inst_Num
=
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Read_Width
);
constexpr
index_t
C_MFMA_Inst_Num
=
MPerBlock
*
NPerBlock
*
KPerBlock
/
(
BlockSize
/
WaveSize
)
/
(
MPerXDL
*
NPerXDL
*
KPerXDL
);
auto
str
=
std
::
stringstream
{};
str
<<
"A/B vector size: "
<<
GetVectorSizeA
()
<<
", "
<<
GetVectorSizeB
()
<<
"
\n
"
<<
"A/B LDS read/write width: "
<<
A_LDS_Read_Width
<<
", "
<<
B_LDS_Read_Width
<<
"
\n
"
<<
"A/B buffer load inst: "
<<
A_Buffer_Load_Inst_Num
<<
", "
<<
B_Buffer_Load_Inst_Num
<<
"
\n
"
<<
"A/B LDS write inst: "
<<
A_LDS_Write_Inst_Num
<<
", "
<<
B_LDS_Write_Inst_Num
<<
"
\n
"
<<
"A/B LDS read inst: "
<<
A_LDS_Read_Inst_Num
<<
", "
<<
B_LDS_Read_Inst_Num
<<
"
\n
"
<<
"C MFMA inst: "
<<
C_MFMA_Inst_Num
<<
"
\n
"
<<
"KPack: "
<<
BlockGemm
::
Traits
::
KPack
<<
"
\n
"
<<
"PrefetchStages: "
<<
PrefetchStages
<<
"
\n
"
;
return
str
.
str
();
}
template
<
GemmPipelineScheduler
Scheduler
>
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
:
public
PipelineImplBase
struct
PipelineImpl
:
public
PipelineImplBase
{
{
...
@@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
{
{
constexpr
index_t
MPerXDL
=
BlockGemm
Shape
::
Warp
Tile
::
at
(
I0
{})
;
constexpr
index_t
MPerXDL
=
BlockGemm
::
Warp
Gemm
::
kM
;
constexpr
index_t
NPerXDL
=
BlockGemm
Shape
::
Warp
Tile
::
at
(
I1
{})
;
constexpr
index_t
NPerXDL
=
BlockGemm
::
Warp
Gemm
::
kN
;
constexpr
index_t
KPerXDL
=
BlockGemm
Shape
::
WarpTile
::
at
(
I2
{})
;
constexpr
index_t
KPerXDL
=
BlockGemm
::
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kK
;
constexpr
index_t
WaveSize
=
64
;
constexpr
index_t
WaveSize
=
64
;
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
I0
{});
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
I0
{});
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
I1
{});
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
I1
{});
constexpr
index_t
A_LDS_Read_Width
=
KPerXDL
;
// Below should be equal to AK1|BK1
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
constexpr
index_t
A_LDS_Read_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Read_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_LDS_Write_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Write_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_Buffer_Load_Inst_Num
=
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeA
());
MPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeA
());
constexpr
index_t
B_Buffer_Load_Inst_Num
=
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeB
());
NPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeB
());
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
A_LDS_Write_Inst_Num
=
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Write_Width
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Write_Width
);
constexpr
index_t
A_LDS_Read_Inst_Num
=
constexpr
index_t
A_LDS_Read_Inst_Num
=
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Read_Width
);
constexpr
index_t
B_LDS_Read_Inst_Num
=
constexpr
index_t
B_LDS_Read_Inst_Num
=
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Read_Width
);
constexpr
index_t
C_MFMA_Inst_Num
=
MPerBlock
*
NPerBlock
*
KPerBlock
/
constexpr
index_t
C_MFMA_Inst_Num
=
MPerBlock
*
NPerBlock
*
KPerBlock
/
(
BlockSize
/
WaveSize
)
/
(
BlockSize
/
WaveSize
)
/
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
8ce41034
...
@@ -90,7 +90,7 @@ struct BaseGemmPipelineAgBgCrMem
...
@@ -90,7 +90,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
// LocalSharedMemoryBuffer: 1
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineA
GmemBGmemCRegV1Default
Policy
>
template
<
typename
Problem
,
typename
Policy
=
Universal
GemmPipelineA
gBgCr
Policy
>
struct
GemmPipelineAgBgCrMem
:
public
BaseGemmPipelineAgBgCrMem
<
Problem
>
struct
GemmPipelineAgBgCrMem
:
public
BaseGemmPipelineAgBgCrMem
<
Problem
>
{
{
using
Base
=
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
using
Base
=
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
...
@@ -165,11 +165,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -165,11 +165,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
constexpr
bool
is_a_col_major
=
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
;
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
constexpr
bool
is_b_row_major
=
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
;
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
static_assert
(
is_a_col_major
?
(
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"A block window has incorrect lengths for defined ALayout!"
);
static_assert
(
is_b_row_major
?
(
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"B block window has incorrect lengths for defined BLayout!"
);
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// Definitions of all needed tiles
...
@@ -213,25 +224,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -213,25 +224,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array
<
ABlockTile
,
PrefetchStages
>
a_block_tiles
;
tuple_array
<
ABlockTile
,
PrefetchStages
>
a_block_tiles
;
tuple_array
<
BBlockTile
,
PrefetchStages
>
b_block_tiles
;
tuple_array
<
BBlockTile
,
PrefetchStages
>
b_block_tiles
;
using
ADramTileWindowStep
=
typename
ADramBlockWindowTmp
::
BottomTensorIndex
;
using
BDramTileWindowStep
=
typename
BDramBlockWindowTmp
::
BottomTensorIndex
;
constexpr
ADramTileWindowStep
a_dram_tile_window_step
=
is_a_col_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
constexpr
BDramTileWindowStep
b_dram_tile_window_step
=
is_b_row_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
// -----------------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// Gemm pipeline start
// prefetch
// prefetch
// global read 0
// global read 0
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
);
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
,
b_dram_tile_window_step
);
// initialize C
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
// LDS write 0
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
I0
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
I0
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
}
// Global prefetch [1, PrefetchStages]
// Global prefetch [1, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
,
b_dram_tile_window_step
);
});
});
// main body
// main body
...
@@ -247,19 +292,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -247,19 +292,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds
();
block_sync_lds
();
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_element_func
);
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_element_func
);
b_element_func
);
}
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
b_copy_dram_window
,
b_dram_tile_window_step
);
});
});
i
+=
PrefetchStages
;
i
+=
PrefetchStages
;
...
@@ -275,12 +346,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -275,12 +346,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds
();
block_sync_lds
();
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_element_func
);
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_element_func
);
b_element_func
);
}
});
});
block_sync_lds
();
block_sync_lds
();
...
@@ -352,11 +443,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -352,11 +443,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
constexpr
bool
is_a_col_major
=
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
;
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
constexpr
bool
is_b_row_major
=
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
;
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
static_assert
(
is_a_col_major
?
(
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"A block window has incorrect lengths for defined ALayout!"
);
static_assert
(
is_b_row_major
?
(
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"B block window has incorrect lengths for defined BLayout!"
);
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// Definitions of all needed tiles
...
@@ -400,25 +502,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -400,25 +502,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array
<
ABlockTile
,
PrefetchStages
>
a_block_tiles
;
tuple_array
<
ABlockTile
,
PrefetchStages
>
a_block_tiles
;
tuple_array
<
BBlockTile
,
PrefetchStages
>
b_block_tiles
;
tuple_array
<
BBlockTile
,
PrefetchStages
>
b_block_tiles
;
using
ADramTileWindowStep
=
typename
ADramBlockWindowTmp
::
BottomTensorIndex
;
using
BDramTileWindowStep
=
typename
BDramBlockWindowTmp
::
BottomTensorIndex
;
constexpr
ADramTileWindowStep
a_dram_tile_window_step
=
is_a_col_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
constexpr
BDramTileWindowStep
b_dram_tile_window_step
=
is_b_row_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
// -----------------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// Gemm pipeline start
// prefetch
// prefetch
// global read 0
// global read 0
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
);
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
,
b_dram_tile_window_step
);
// initialize C
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
// LDS write 0
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
I0
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
I0
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
}
// Global prefetch [1, PrefetchStages]
// Global prefetch [1, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
,
b_dram_tile_window_step
);
});
});
// main body
// main body
...
@@ -432,19 +567,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -432,19 +567,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
// no second block_sync_lds because it's interwave
// no second block_sync_lds because it's interwave
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_element_func
);
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_element_func
);
b_element_func
);
}
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
b_copy_dram_window
,
b_dram_tile_window_step
);
});
});
i
+=
PrefetchStages
;
i
+=
PrefetchStages
;
...
@@ -457,12 +618,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -457,12 +618,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
// no second block_sync_lds because it's interwave
// no second block_sync_lds because it's interwave
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_element_func
);
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_element_func
);
b_element_func
);
}
});
});
block_sync_lds
();
block_sync_lds
();
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
8ce41034
...
@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
...
@@ -519,7 +518,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -519,7 +518,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
k
N
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
k
M
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp
0 → 100644
View file @
8ce41034
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
ReferenceMXGemm
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
ScaleDataType
>&
a_m_kblock_scales
,
const
Tensor
<
BDataType
>&
b_k_n
,
const
Tensor
<
ScaleDataType
>&
b_kblock_n_scales
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
a_m_k_
{
a_m_k
},
a_m_kblock_scales_
{
a_m_kblock_scales
},
b_k_n_
{
b_k_n
},
b_kblock_n_scales_
{
b_kblock_n_scales
},
c_m_n_
{
c_m_n
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
}
const
Tensor
<
ADataType
>&
a_m_k_
;
const
Tensor
<
ScaleDataType
>&
a_m_kblock_scales_
;
const
Tensor
<
BDataType
>&
b_k_n_
;
const
Tensor
<
ScaleDataType
>&
b_kblock_n_scales_
;
Tensor
<
CDataType
>&
c_m_n_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
};
// Invoker
struct
Invoker
:
public
device
::
BaseInvoker
{
using
Argument
=
ReferenceMXGemm
::
Argument
;
float
Run
(
const
Argument
&
arg
)
{
using
GemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ComputeTypeA
,
ComputeTypeB
,
CDataType
,
AccDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
ComputeTypeA
,
ComputeTypeB
>
;
Tensor
<
ComputeTypeA
>
a_m_k_scaled
(
arg
.
a_m_k_
.
mDesc
);
Tensor
<
ComputeTypeB
>
b_k_n_scaled
(
arg
.
b_k_n_
.
mDesc
);
const
auto
M
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
0
];
const
auto
N
=
arg
.
b_k_n_
.
mDesc
.
GetLengths
()[
1
];
const
auto
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
const
auto
SCALE_BLOCK
=
K
/
arg
.
a_m_kblock_scales_
.
mDesc
.
GetLengths
()[
1
];
for
(
size_t
m
=
0
;
m
<
M
;
m
++
)
{
for
(
size_t
k
=
0
;
k
<
K
;
k
++
)
{
a_m_k_scaled
(
m
,
k
)
=
type_convert
<
ComputeTypeA
>
(
arg
.
a_m_k_
(
m
,
k
))
*
type_convert
<
ComputeTypeA
>
(
arg
.
a_m_kblock_scales_
(
m
,
k
/
SCALE_BLOCK
));
}
}
for
(
size_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
size_t
k
=
0
;
k
<
K
;
k
++
)
{
b_k_n_scaled
(
k
,
n
)
=
type_convert
<
ComputeTypeB
>
(
arg
.
b_k_n_
(
k
,
n
))
*
type_convert
<
ComputeTypeB
>
(
arg
.
b_kblock_n_scales_
(
k
/
SCALE_BLOCK
,
n
));
}
}
auto
ref_gemm
=
GemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k_scaled
,
b_k_n_scaled
,
arg
.
c_m_n_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
ref_invoker
.
Run
(
ref_argument
);
return
0
;
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
ScaleDataType
>&
a_m_kblock_scales
,
const
Tensor
<
BDataType
>&
b_k_n
,
const
Tensor
<
ScaleDataType
>&
b_kblock_n_scales
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
a_m_k
,
a_m_kblock_scales
,
b_k_n
,
b_kblock_n_scales
,
c_m_n
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceMXGemm"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
View file @
8ce41034
...
@@ -72,7 +72,7 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple
...
@@ -72,7 +72,7 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple
>
;
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances
=
std
::
tuple
<
using
device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
...
@@ -86,16 +86,34 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple
...
@@ -86,16 +86,34 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
64
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
64
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
// Memory friendly
// Memory friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
64
,
16
,
256
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
256
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
32
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
256
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
32
,
2
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
32
,
2
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
32
,
64
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
256
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
64
,
256
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
32
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
32
,
64
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
16
,
64
,
8
,
8
,
16
,
16
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
2
,
2
,
16
,
16
,
1
,
1
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
64
,
8
,
8
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
64
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
64
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
// clang-format on
>
;
>
;
#endif // defined(CK_USE_AMD_MFMA_GFX950)
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
8ce41034
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <tuple>
...
@@ -14,28 +14,26 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
...
@@ -14,28 +14,26 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Intrawave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
using
Intrawave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
//
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
//
ck_tile::GemmPipelineScheduler::Interwave>;
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
//
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
// clang-format off
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
//
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
//
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
//
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
//
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
//
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
//
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
//
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
//
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
>
;
>
;
// clang-format on
// clang-format on
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment