Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e92395d9
Commit
e92395d9
authored
Dec 27, 2024
by
coderfeli
Browse files
Merge remote-tracking branch 'origin/cka8w8_devtimer' into update_cka8w8_uc
parents
842d910e
7efafa11
Changes
81
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1030 additions
and
445 deletions
+1030
-445
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+69
-205
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+187
-72
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+16
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+259
-44
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+274
-0
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+4
-0
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
...ce_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
+3
-0
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
...device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
+2
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp
...evice_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp
+15
-3
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp
...evice_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp
+24
-5
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
.../include/profiler/profile_gemm_multiply_multiply_impl.hpp
+4
-3
profiler/include/profiler/profile_gemm_universal_batched_impl.hpp
.../include/profiler/profile_gemm_universal_batched_impl.hpp
+80
-68
profiler/include/profiler/profile_gemm_universal_impl.hpp
profiler/include/profiler/profile_gemm_universal_impl.hpp
+12
-6
profiler/include/profiler/profile_grouped_gemm_impl.hpp
profiler/include/profiler/profile_grouped_gemm_impl.hpp
+1
-1
profiler/src/profile_gemm_universal_batched.cpp
profiler/src/profile_gemm_universal_batched.cpp
+11
-9
script/process_perf_data.py
script/process_perf_data.py
+1
-1
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+20
-22
test/ck_tile/gemm/CMakeLists.txt
test/ck_tile/gemm/CMakeLists.txt
+1
-1
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+42
-0
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
+5
-5
No files found.
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
e92395d9
...
...
@@ -3,90 +3,93 @@
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
namespace
ck_tile
{
struct
BatchedGemmHostArgs
struct
BatchedGemmHostArgs
:
public
ck_tile
::
GemmHostArgs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
index_t
batch_stride_A
;
index_t
batch_stride_B
;
index_t
batch_stride_C
;
index_t
batch_count
;
CK_TILE_HOST
BatchedGemmHostArgs
()
=
default
;
CK_TILE_HOST
BatchedGemmHostArgs
(
const
void
*
a_ptr_
,
const
void
*
b_ptr_
,
void
*
c_ptr_
,
ck_tile
::
index_t
k_batch_
,
ck_tile
::
index_t
M_
,
ck_tile
::
index_t
N_
,
ck_tile
::
index_t
K_
,
ck_tile
::
index_t
stride_A_
,
ck_tile
::
index_t
stride_B_
,
ck_tile
::
index_t
stride_C_
,
ck_tile
::
index_t
batch_stride_A_
,
ck_tile
::
index_t
batch_stride_B_
,
ck_tile
::
index_t
batch_stride_C_
,
ck_tile
::
index_t
batch_count_
)
:
GemmHostArgs
(
a_ptr_
,
b_ptr_
,
c_ptr_
,
k_batch_
,
M_
,
N_
,
K_
,
stride_A_
,
stride_B_
,
stride_C_
),
batch_stride_A
(
batch_stride_A_
),
batch_stride_B
(
batch_stride_B_
),
batch_stride_C
(
batch_stride_C_
),
batch_count
(
batch_count_
)
{
}
ck_tile
::
index_t
batch_stride_A
;
ck_tile
::
index_t
batch_stride_B
;
ck_tile
::
index_t
batch_stride_C
;
ck_tile
::
index_t
batch_count
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
BatchedGemmKernel
struct
BatchedGemmKernel
:
public
GemmKernel
<
TilePartitioner_
,
GemmPipeline_
,
EpiloguePipeline_
>
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmPipeline
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
CLayout
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
BlockSize
;
using
Base
=
GemmKernel
<
TilePartitioner_
,
GemmPipeline_
,
EpiloguePipeline_
>
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
GemmKernelArgs
=
typename
Base
::
GemmKernelArgs
;
struct
BatchedGemmKargs
using
ADataType
=
typename
Base
::
ADataType
;
using
BDataType
=
typename
Base
::
BDataType
;
using
CDataType
=
typename
Base
::
CDataType
;
using
TilePartitioner
=
typename
Base
::
TilePartitioner
;
using
GemmPipeline
=
typename
Base
::
GemmPipeline
;
using
EpiloguePipeline
=
typename
Base
::
EpiloguePipeline
;
using
ALayout
=
typename
Base
::
ALayout
;
using
BLayout
=
typename
Base
::
BLayout
;
using
CLayout
=
typename
Base
::
CLayout
;
struct
BatchedGemmKernelArgs
:
GemmKernelArgs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
index_t
batch_stride_A
;
index_t
batch_stride_B
;
index_t
batch_stride_C
;
index_t
batch_count
;
};
using
Kargs
=
BatchedGemmKargs
;
using
Hargs
=
BatchedGemmHostArgs
;
using
KernelArgs
=
BatchedGemmKernelArgs
;
__host__
static
constexpr
auto
GridSize
(
const
Hargs
&
h
)
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
batch_count
)
{
return
TilePartitioner
::
GridSize
(
h
.
M
,
h
.
N
,
h
.
batch_count
);
return
TilePartitioner
::
GridSize
(
M
,
N
,
batch_count
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
Base
::
KernelBlockSize
);
}
CK_TILE_HOST
static
constexpr
BatchedGemmKargs
MakeKargs
(
const
Hargs
&
h
)
CK_TILE_HOST
static
constexpr
BatchedGemmKernelArgs
MakeKernelArgs
(
const
BatchedGemmHostArgs
&
hostArgs
)
{
Kargs
k
;
k
.
a_ptr
=
h
.
a_ptr
;
k
.
b_ptr
=
h
.
b_ptr
;
k
.
c_ptr
=
h
.
c_ptr
;
k
.
M
=
h
.
M
;
k
.
N
=
h
.
N
;
k
.
K
=
h
.
K
;
k
.
stride_A
=
h
.
stride_A
;
k
.
stride_B
=
h
.
stride_B
;
k
.
stride_C
=
h
.
stride_C
;
k
.
batch_stride_A
=
h
.
batch_stride_A
;
k
.
batch_stride_B
=
h
.
batch_stride_B
;
k
.
batch_stride_C
=
h
.
batch_stride_C
;
k
.
batch_count
=
h
.
batch_count
;
return
k
;
return
BatchedGemmKernelArgs
{{
hostArgs
.
a_ptr
,
hostArgs
.
b_ptr
,
hostArgs
.
c_ptr
,
hostArgs
.
M
,
hostArgs
.
N
,
hostArgs
.
K
,
hostArgs
.
stride_A
,
hostArgs
.
stride_B
,
hostArgs
.
stride_C
},
hostArgs
.
batch_stride_A
,
hostArgs
.
batch_stride_B
,
hostArgs
.
batch_stride_C
,
hostArgs
.
batch_count
};
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
@@ -94,7 +97,7 @@ struct BatchedGemmKernel
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_DEVICE
void
operator
()(
Ka
rgs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
BatchedGemmKernelA
rgs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
const
auto
i_batch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
...
...
@@ -102,156 +105,17 @@ struct BatchedGemmKernel
// options
const
auto
batch_stride_A
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_A
);
const
auto
batch_offset_A
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_A
);
const
ADataType
*
a_
start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
ADataType
*
a_
ptr
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
)
+
batch_offset_A
;
const
auto
batch_stride_B
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_B
);
const
auto
batch_offset_B
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_B
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
+
batch_offset_A
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
+
batch_offset_A
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
+
batch_offset_B
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
1
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
+
batch_offset_B
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
1
>
{});
}
}();
auto
a_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
auto
a_block_window
=
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
auto
b_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadN
,
false
>
{});
}
}();
// clang-format on
auto
b_block_window
=
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
// Run GEMM cooperatively by whole wokrgroup.
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
const
BDataType
*
b_ptr
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
)
+
batch_offset_B
;
const
auto
batch_stride_C
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_C
);
const
auto
batch_offset_C
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_C
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
+
batch_offset_C
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSizeC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
+
batch_offset_C
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
c_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadN
>
{});
}
else
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
auto
c_block_window
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
CDataType
*
c_ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
)
+
batch_offset_C
;
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
this
->
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
,
i_m
,
i_n
);
}
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
e92395d9
...
...
@@ -12,6 +12,50 @@
namespace
ck_tile
{
struct
GemmProblem
{
CK_TILE_HOST
GemmProblem
()
=
default
;
CK_TILE_HOST
GemmProblem
(
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
stride_A_
,
index_t
stride_B_
,
index_t
stride_C_
)
:
M
(
M_
),
N
(
N_
),
K
(
K_
),
stride_A
(
stride_A_
),
stride_B
(
stride_B_
),
stride_C
(
stride_C_
)
{
}
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
};
struct
GemmHostArgs
:
public
GemmProblem
{
CK_TILE_HOST
GemmHostArgs
()
=
default
;
CK_TILE_HOST
GemmHostArgs
(
const
void
*
a_ptr_
,
const
void
*
b_ptr_
,
void
*
c_ptr_
,
index_t
k_batch_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
stride_A_
,
index_t
stride_B_
,
index_t
stride_C_
)
:
GemmProblem
(
M_
,
N_
,
K_
,
stride_A_
,
stride_B_
,
stride_C_
),
a_ptr
(
a_ptr_
),
b_ptr
(
b_ptr_
),
c_ptr
(
c_ptr_
),
k_batch
(
k_batch_
)
{
}
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
k_batch
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
GemmKernel
{
...
...
@@ -25,9 +69,12 @@ struct GemmKernel
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
static
constexpr
auto
I0
=
number
<
0
>
();
static
constexpr
auto
I1
=
number
<
1
>
();
static
constexpr
auto
I2
=
number
<
2
>
();
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
{
return
TilePartitioner
::
GridSize
(
M
,
N
,
KBatch
);
...
...
@@ -35,7 +82,7 @@ struct GemmKernel
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
struct
Gemm
CommonKa
rgs
struct
Gemm
KernelA
rgs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
...
...
@@ -48,25 +95,37 @@ struct GemmKernel
index_t
stride_C
;
};
CK_TILE_HOST
static
constexpr
GemmCommonKargs
MakeKargs
(
const
void
*
a_ptr
,
const
void
*
b_ptr
,
void
*
c_ptr
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
stride_C
)
CK_TILE_HOST
static
constexpr
GemmKernelArgs
MakeKernelArgs
(
const
GemmHostArgs
&
hostArgs
)
{
return
GemmCommonKargs
{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
return
GemmKernelArgs
{
hostArgs
.
a_ptr
,
hostArgs
.
b_ptr
,
hostArgs
.
c_ptr
,
hostArgs
.
M
,
hostArgs
.
N
,
hostArgs
.
K
,
hostArgs
.
stride_A
,
hostArgs
.
stride_B
,
hostArgs
.
stride_C
};
}
// CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
// const void* b_ptr,
// void* c_ptr,
// index_t M,
// index_t N,
// index_t K,
// index_t stride_A,
// index_t stride_B,
// index_t stride_C)
// {
// return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
// }
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
Gemm
CommonKa
rgs
&
kargs
)
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
Gemm
KernelA
rgs
&
kargs
)
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
...
...
@@ -139,18 +198,16 @@ struct GemmKernel
return
true
;
}
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
CK_TILE_DEVICE
auto
MakeGemmTensorViews
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
const
GemmKernelArgs
&
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
// options
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
const
auto
&
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_
start
,
a_
ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
...
...
@@ -159,7 +216,7 @@ struct GemmKernel
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_
start
,
a_
ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
1
>
{},
...
...
@@ -167,11 +224,11 @@ struct GemmKernel
}
}();
auto
b_tensor_view
=
[
&
]()
{
const
auto
&
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_
start
,
b_
ptr
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
1
>
{},
...
...
@@ -180,7 +237,7 @@ struct GemmKernel
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_
start
,
b_
ptr
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
...
...
@@ -188,7 +245,35 @@ struct GemmKernel
}
}();
auto
a_pad_view
=
[
&
]()
{
const
auto
&
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSizeC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
return
make_tuple
(
a_tensor_view
,
b_tensor_view
,
c_tensor_view
);
}
template
<
typename
TensorView
>
CK_TILE_DEVICE
auto
MakeGemmPadViews
(
const
TensorView
&
views
)
const
{
const
auto
&
a_pad_view
=
[
&
]()
{
const
auto
&
a_tensor_view
=
views
.
at
(
I0
);
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
...
...
@@ -204,14 +289,9 @@ struct GemmKernel
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
auto
a_block_window
=
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
auto
b_pad_view
=
[
&
]()
{
const
auto
&
b_pad_view
=
[
&
]()
{
const
auto
&
b_tensor_view
=
views
.
at
(
I1
);
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
pad_tensor_view
(
...
...
@@ -228,43 +308,8 @@ struct GemmKernel
}
}();
auto
b_block_window
=
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
// Run GEMM cooperatively by whole wokrgroup.
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSizeC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
c_pad_view
=
[
&
]()
{
const
auto
&
c_pad_view
=
[
&
]()
{
const
auto
&
c_tensor_view
=
views
.
at
(
I2
);
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
...
...
@@ -280,12 +325,82 @@ struct GemmKernel
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
auto
CBlockWindow_pad
=
make_tile_window
(
return
make_tuple
(
a_pad_view
,
b_pad_view
,
c_pad_view
);
}
template
<
typename
PadView
>
CK_TILE_DEVICE
auto
MakeGemmTileWindows
(
const
PadView
&
views
,
const
index_t
i_m
,
const
index_t
i_n
)
const
{
const
auto
&
a_pad_view
=
views
.
at
(
I0
);
const
auto
&
a_block_window
=
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
const
auto
&
b_pad_view
=
views
.
at
(
I1
);
const
auto
&
b_block_window
=
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
const
auto
&
c_pad_view
=
views
.
at
(
I2
);
auto
c_block_window
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
EpiloguePipeline
{}(
CBlockWindow_pad
,
c_block_tile
);
return
make_tuple
(
a_block_window
,
b_block_window
,
c_block_window
);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param kargs GEMM kernel arguments
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*/
CK_TILE_DEVICE
void
RunGemm
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
const
GemmKernelArgs
&
kargs
,
const
index_t
block_idx_m
,
const
index_t
block_idx_n
)
const
{
// Create Gemm tensor views, pad views and tile windows
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
);
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
auto
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
// Run GEMM cooperatively by whole workgroup.
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
}
CK_TILE_DEVICE
void
operator
()(
GemmKernelArgs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
// options
const
ADataType
*
a_ptr
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_ptr
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
CDataType
*
c_ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
,
i_m
,
i_n
);
}
};
...
...
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
e92395d9
...
...
@@ -56,6 +56,14 @@ using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M4N64K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M4N64K4
<
WGAttrCtlEnum
::
Default_
>
,
4
>>
;
using
WarpGemmMfmaF16F16F32M64N4K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M64N4K4
<
WGAttrCtlEnum
::
Default_
>
,
4
>>
;
// bf16
using
WarpGemmMfmaBf16Bf16F32M32N32K8
=
WarpGemmImpl
<
...
...
@@ -104,6 +112,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M4N64K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
<
WGAttrCtlEnum
::
Default_
>
,
4
>>
;
using
WarpGemmMfmaBf16Bf16F32M64N4K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
<
WGAttrCtlEnum
::
Default_
>
,
4
>>
;
// fp8
using
WarpGemmMfma_f32_32x32x16_fp8_fp8
=
WarpGemmImpl
<
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
e92395d9
...
...
@@ -28,6 +28,9 @@ struct WarpGemmAtrributeMfma
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
...
...
@@ -94,30 +97,130 @@ struct WarpGemmAtrributeMfmaIterateK
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
static_assert
(
Impl
::
kAMBlock
==
1
||
Impl
::
kBNBlock
==
1
,
"Multi-block on both M & N directions is not supported"
);
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
CK_TILE_DEVICE
static
constexpr
auto
get_awarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
// each M blocks share the same data
return
tile_distribution_encoding
<
sequence
<
Impl
::
kBNBlock
>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
0
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
// single block to multi-block thread mapping
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMBlock
,
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
1
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
}
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
;
CK_TILE_DEVICE
static
constexpr
auto
get_bwarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
// single block to multi-block thread mapping
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNBlock
,
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
1
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
// each N blocks share the same data
return
tile_distribution_encoding
<
sequence
<
Impl
::
kAMBlock
>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
0
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
}
CK_TILE_DEVICE
static
constexpr
auto
get_cwarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kBNBlock
*
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kAMBlock
*
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
{};
}
}
using
AWarpDstrEncoding
=
decltype
(
get_awarp_dstr_encoding
());
using
BWarpDstrEncoding
=
decltype
(
get_bwarp_dstr_encoding
());
using
CWarpDstrEncoding
=
decltype
(
get_cwarp_dstr_encoding
());
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
...
...
@@ -206,6 +309,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
...
...
@@ -270,6 +376,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
...
...
@@ -341,30 +450,130 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
static_assert
(
Impl
::
kAMBlock
==
1
||
Impl
::
kBNBlock
==
1
,
"Multi-block on both M & N directions is not supported"
);
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
CK_TILE_DEVICE
static
constexpr
auto
get_awarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
// single block to multi-block thread mapping
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNBlock
,
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
1
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
// each N blocks share the same data
return
tile_distribution_encoding
<
sequence
<
Impl
::
kAMBlock
>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
0
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
}
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
CK_TILE_DEVICE
static
constexpr
auto
get_bwarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
// each M blocks share the same data
return
tile_distribution_encoding
<
sequence
<
Impl
::
kBNBlock
>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
0
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
// single block to multi-block thread mapping
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMBlock
,
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
1
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
}
CK_TILE_DEVICE
static
constexpr
auto
get_cwarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNBlock
*
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kAMBlock
*
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
{};
}
}
using
AWarpDstrEncoding
=
decltype
(
get_awarp_dstr_encoding
());
using
BWarpDstrEncoding
=
decltype
(
get_bwarp_dstr_encoding
());
using
CWarpDstrEncoding
=
decltype
(
get_cwarp_dstr_encoding
());
template
<
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
...
...
@@ -457,6 +666,9 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
...
...
@@ -597,6 +809,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
/
(
Impl
::
kCMLane
*
SFactor
*
Impl
::
kCM1PerLane
),
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
e92395d9
...
...
@@ -78,6 +78,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
...
...
@@ -138,6 +141,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kABKLane
=
4
;
...
...
@@ -182,6 +188,134 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M4N64K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
fp16_t
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
4
;
static
constexpr
index_t
kN
=
64
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
16
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4f16"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M64N4K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
fp16_t
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
64
;
static
constexpr
index_t
kN
=
4
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
16
;
static
constexpr
index_t
kBNBlock
=
1
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4f16"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
// Bf16
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...
...
@@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
...
...
@@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kABKLane
=
4
;
...
...
@@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
4
;
static
constexpr
index_t
kN
=
64
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
16
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4bf16_1k"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
64
;
static
constexpr
index_t
kN
=
4
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
16
;
static
constexpr
index_t
kBNBlock
=
1
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4bf16_1k"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
// FP8
template
<
typename
AType_
,
typename
BType_
,
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...
...
@@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
...
...
@@ -568,6 +839,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
e92395d9
...
...
@@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
4
,
64
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M4N64K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
64
,
4
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M64N4K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
...
...
@@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
4
,
64
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M4N64K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
64
,
4
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M64N4K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
View file @
e92395d9
...
...
@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
224
,
256
,
64
,
8
,
8
,
16
,
16
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
2
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
224
,
64
,
8
,
8
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
160
,
64
,
8
,
8
,
16
,
16
,
8
,
5
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
160
,
64
,
8
,
8
,
32
,
32
,
1
,
5
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
160
,
128
,
64
,
8
,
8
,
32
,
32
,
5
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
View file @
e92395d9
...
...
@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef __gfx94__
// Compute friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
F8
,
F8
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
...
...
@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
F8
,
F8
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
F8
,
F8
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp
100644 → 100755
View file @
e92395d9
...
...
@@ -41,6 +41,8 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = st
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
8
,
4
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
2
,
2
,
32
,
32
,
4
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
4
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
8
,
4
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
8
,
4
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
...
...
@@ -49,7 +51,9 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = st
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
4
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
256
,
32
,
8
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
128
,
32
,
8
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
4
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
4
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
// clang-format on
>
;
...
...
@@ -61,14 +65,21 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
// Latency friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
8
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
8
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
2
,
2
,
16
,
16
,
1
,
1
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
8
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
32
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
32
,
2
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
32
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
8
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
2
,
2
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
8
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
// Memory friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
32
,
64
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
32
,
64
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
16
,
64
,
8
,
2
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
16
,
64
,
2
,
2
,
16
,
16
,
4
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
32
,
64
,
8
,
4
,
32
,
32
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
16
,
64
,
8
,
4
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
32
,
64
,
8
,
4
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
...
...
@@ -82,6 +93,7 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
64
,
8
,
4
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
128
,
64
,
8
,
4
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
8
,
4
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
4
,
4
,
16
,
16
,
1
,
4
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
64
,
8
,
4
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp
100644 → 100755
View file @
e92395d9
...
...
@@ -42,14 +42,21 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = st
// Compute friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
2
,
2
,
32
,
32
,
4
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
2
,
2
,
32
,
32
,
4
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
8
,
8
,
16
,
16
,
8
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
// AGPR Spill
//
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32,
8
,
8
, 16, 16, 8, 8, S<
4, 64
, 1>, S<1, 0, 2>, S<1, 0, 2>, 2,
8
,
8
,
1
, S<
4, 64
, 1>, S<1, 0, 2>, S<1, 0, 2>, 2,
8
,
8
,
1
, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v
5
>,
// AGPR Spill when use permuted lds layout. so, use padding for these two.
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
16
,
16
,
8
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
2
,
2
,
16
,
16
,
8
,
8
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v
3
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
224
,
256
,
64
,
8
,
8
,
16
,
16
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
224
,
64
,
8
,
8
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
...
...
@@ -68,15 +75,23 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
// Latency friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
2
,
2
,
16
,
16
,
1
,
1
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
2
,
2
,
16
,
16
,
1
,
1
,
S
<
32
,
2
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
2
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
// Memory friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
32
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
32
,
64
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
32
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
32
,
64
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
16
,
64
,
8
,
8
,
16
,
16
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
...
...
@@ -84,12 +99,16 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
2
,
2
,
16
,
16
,
1
,
1
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
64
,
8
,
8
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
64
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffle_Streamk_V3
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
64
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
// clang-format on
>
;
}
// namespace instance
...
...
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
View file @
e92395d9
...
...
@@ -270,11 +270,12 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
hipStream_t
stream
;
hip_check_error
(
hipStreamCreate
(
&
stream
));
// timer of develop branch should only apply to empty hipstream
// hipStream_t stream;
// hip_check_error(hipStreamCreate(&stream));
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
stream
,
StreamConfig
{
nullptr
,
time_kernel
,
0
,
n_warmup
,
...
...
profiler/include/profiler/profile_gemm_universal_batched_impl.hpp
View file @
e92395d9
...
...
@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
int
StrideB
,
int
StrideC
,
int
BatchCount
,
int
KBatch
,
int
n_warmup
,
int
n_iter
,
uint64_t
rotating
=
0
)
...
...
@@ -147,89 +148,100 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_kbatch
=
0
;
// profile device op instances
for
(
auto
&
op_ptr
:
op_ptrs
)
{
std
::
unique_ptr
<
tensor_operation
::
device
::
BaseArgument
>
argument_ptr
;
// false branch for multi d dl kernel
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
{},
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
BatchCount
,
StrideA
,
StrideB
,
{},
StrideC
,
BatchStrideA
,
BatchStrideB
,
{},
BatchStrideC
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
// re-init C to zero before profiling next kernel
c_device_buf
.
SetZero
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
,
16
,
19
,
32
,
38
};
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
,
0
,
n_warmup
,
n_iter
,
true
,
rotating_count
});
if
(
KBatch
>
0
)
{
kbatch_list
=
{
KBatch
};
}
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
BatchCount
*
M
*
N
*
K
;
for
(
std
::
size_t
i
=
0
;
i
<
kbatch_list
.
size
();
i
++
)
{
auto
kbatch_curr
=
kbatch_list
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
{},
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
BatchCount
,
StrideA
,
StrideB
,
{},
StrideC
,
BatchStrideA
,
BatchStrideB
,
{},
BatchStrideC
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
kbatch_curr
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
)
*
BatchC
ount
;
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
,
0
,
n_warmup
,
n_iter
,
true
,
rotating_c
ount
})
;
float
t
flop
s
=
st
atic_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
std
::
size_t
flop
=
st
d
::
size_t
(
2
)
*
BatchCount
*
M
*
N
*
K
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
)
*
BatchCount
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
if
(
tflops
>
best_tflops
)
{
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
<<
kbatch_curr
<<
std
::
endl
;
pass
=
pass
&
ck
::
utils
::
check_err
(
c_g_m_n_device_result
,
c_g_m_n_host_result
);
if
(
tflops
>
best_tflops
)
{
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
best_kbatch
=
kbatch_curr
;
}
if
(
do_
log
)
if
(
do_
verification
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_g_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_g_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host: "
,
c_g_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_g_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_g_m_n_device_result
,
c_g_m_n_host_result
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_g_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_g_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host: "
,
c_g_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_g_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
}
else
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
else
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
}
}
}
...
...
@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
std
::
cout
<<
" B = "
<<
BatchCount
<<
" M = "
<<
M
<<
" N = "
<<
N
<<
" K = "
<<
K
<<
" StrideA = "
<<
StrideA
<<
" StrideB = "
<<
StrideB
<<
" StrideC = "
<<
StrideC
<<
": "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
" KBatch = "
<<
best_kbatch
<<
": "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
"
TFlops, "
<<
best_gb_per_sec
<<
"
GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
}
...
...
profiler/include/profiler/profile_gemm_universal_impl.hpp
View file @
e92395d9
...
...
@@ -144,6 +144,7 @@ bool profile_gemm_universal_impl(int do_verification,
}
std
::
string
best_op_name
;
std
::
optional
<
std
::
string
>
best_op_object_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
...
...
@@ -225,7 +226,8 @@ bool profile_gemm_universal_impl(int do_verification,
}
}
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
std
::
optional
<
std
::
string
>
op_obj_name
=
op_ptr
->
GetObjectName
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
...
...
@@ -251,11 +253,12 @@ bool profile_gemm_universal_impl(int do_verification,
if
(
tflops
>
best_tflops
&&
ave_time
>
1e-10
)
{
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
best_kbatch
=
kbatch_curr
;
best_op_name
=
op_name
;
best_op_object_name
=
op_obj_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
best_kbatch
=
kbatch_curr
;
}
}
else
...
...
@@ -306,6 +309,9 @@ bool profile_gemm_universal_impl(int do_verification,
<<
" : "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
if
(
best_op_object_name
)
std
::
cout
<<
best_op_object_name
.
value
()
<<
std
::
endl
;
return
pass
;
}
...
...
profiler/include/profiler/profile_grouped_gemm_impl.hpp
View file @
e92395d9
...
...
@@ -77,7 +77,7 @@ bool profile_grouped_gemm_impl(int do_verification,
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_host_results
;
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_device_results
;
ComputeDataTyp
e
max_abs_in_val
=
0.
f
;
doubl
e
max_abs_in_val
=
0.
f
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
a_m_k
.
push_back
(
...
...
profiler/src/profile_gemm_universal_batched.cpp
View file @
e92395d9
...
...
@@ -31,7 +31,7 @@ enum struct GemmDataType
int
profile_batched_gemm_universal
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
1
8
&&
argc
!=
2
1
)
if
(
argc
!=
1
9
&&
argc
!=
2
2
)
{
// clang-format off
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
...
...
@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg6: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: time kernel (0=n0, 1=yes)
\n
"
);
printf
(
"arg8 to 1
7
: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount
\n
"
);
printf
(
"arg8 to 1
8
: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount
, KBatch
\n
"
);
printf
(
"optional:
\n
"
);
printf
(
"arg1
8
: number of warm-up cycles (default 1)
\n
"
);
printf
(
"arg
19
: number of iterations (default 10)
\n
"
);
printf
(
"arg2
0
: memory for rotating buffer (default 0, size in MB)
\n
"
);
printf
(
"arg1
9
: number of warm-up cycles (default 1)
\n
"
);
printf
(
"arg
20
: number of iterations (default 10)
\n
"
);
printf
(
"arg2
1
: memory for rotating buffer (default 0, size in MB)
\n
"
);
// clang-format on
exit
(
1
);
}
...
...
@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
int
n_warmup
=
1
;
int
n_iter
=
10
;
uint64_t
rotating
=
0
;
if
(
argc
==
2
1
)
if
(
argc
==
2
2
)
{
n_warmup
=
std
::
stoi
(
argv
[
1
8
]);
n_iter
=
std
::
stoi
(
argv
[
19
]);
rotating
=
std
::
stoull
(
argv
[
2
0
])
*
1024
*
1024
;
n_warmup
=
std
::
stoi
(
argv
[
1
9
]);
n_iter
=
std
::
stoi
(
argv
[
20
]);
rotating
=
std
::
stoull
(
argv
[
2
1
])
*
1024
*
1024
;
}
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
...
...
@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
const
int
BatchStrideC
=
std
::
stoi
(
argv
[
16
]);
const
int
BatchCount
=
std
::
stoi
(
argv
[
17
]);
const
int
KBatch
=
std
::
stoi
(
argv
[
18
]);
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using
F8
=
ck
::
f8_t
;
...
...
@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
StrideB_
,
StrideC_
,
BatchCount
,
KBatch
,
n_warmup
,
n_iter
,
rotating
);
...
...
script/process_perf_data.py
View file @
e92395d9
...
...
@@ -332,7 +332,7 @@ def main():
table_name
=
"ck_fmha_bwd_tflops"
tflops_base
=
get_baseline
(
table_name
,
conn
)
store_new_test_result
(
table_name
,
results
,
testlist
,
branch_name
,
node_id
,
gpu_arch
,
compute_units
,
rocm_vers
,
hip_vers
,
environment
,
conn
)
store_new_test_result
(
table_name
,
results
,
testlist
,
branch_name
,
node_id
,
gpu_arch
,
compute_units
,
rocm_vers
,
hip_vers
,
environment
,
sqlEngine
)
conn
.
close
()
#compare the results to the baseline if baseline exists
...
...
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
e92395d9
...
...
@@ -24,12 +24,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmHostArgs
{
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
invoke_batched_gemm
(
const
batched_gemm_kargs
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
...
...
@@ -94,9 +91,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeK
a
rgs
(
args
);
auto
kargs
=
Kernel
::
MakeK
ernelA
rgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
...
...
@@ -185,21 +182,22 @@ class TestCkTileBatchedGemm : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
batched_gemm_kargs
kargs
{
a_m_k_dev_buf
.
GetDeviceBuffer
(),
b_k_n_dev_buf
.
GetDeviceBuffer
(),
c_m_n_dev_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
BatchCount
};
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
kargs
,
ck_tile
::
BatchedGemmHostArgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
M
=
M
;
args
.
N
=
N
;
args
.
K
=
K
;
args
.
stride_A
=
StrideA
;
args
.
stride_B
=
StrideB
;
args
.
stride_C
=
StrideC
;
args
.
batch_stride_A
=
BatchStrideA
;
args
.
batch_stride_B
=
BatchStrideB
;
args
.
batch_stride_C
=
BatchStrideC
;
args
.
batch_count
=
BatchCount
;
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
false
});
std
::
cout
<<
"Run kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
...
...
test/ck_tile/gemm/CMakeLists.txt
View file @
e92395d9
# Currently ck_tile is only built on gfx9
if
(
GPU_TARGETS MATCHES
"gfx9"
)
add_gtest_executable
(
test_ck_tile_gemm_
mem_
pipeline test_gemm_
mem_
pipeline.cpp
)
add_gtest_executable
(
test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp
)
endif
()
test/ck_tile/gemm/test_gemm_
mem_
pipeline.cpp
→
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
e92395d9
...
...
@@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_gemm_
mem_
pipeline_util.hpp"
#include "test_gemm_pipeline_util.hpp"
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
...
...
@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileGemm
Mem
Pipeline
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestCkTileGemmPipeline
,
KernelTypes
);
#include "test_gemm_
mem_
pipeline_ut_cases.inc"
#include "test_gemm_pipeline_ut_cases.inc"
test/ck_tile/gemm/test_gemm_
mem_
pipeline_ut_cases.inc
→
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
View file @
e92395d9
...
...
@@ -3,7 +3,7 @@
#pragma once
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
SmallM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
SmallM
)
{
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
constexpr
int
N
=
1024
;
...
...
@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
MidLargeM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
MidLargeM
)
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
...
...
@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
PaddK
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
PaddK
)
{
std
::
vector
<
int
>
Ms
{
127
};
constexpr
int
N
=
1024
;
...
...
@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
Regular
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
Regular
)
{
std
::
vector
<
int
>
Ms
{
512
};
constexpr
int
N
=
1024
;
...
...
@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
NotSupportedArgument
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
NotSupportedArgument
)
{
constexpr
int
M
=
512
;
constexpr
int
N
=
1025
;
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment