gaoqiong / composable_kernel_ROCM · Commits

Commit b74918bc, authored Jan 06, 2025 by ThomasNing

    compiled version of cross gpu connection

Parents: 3fcad951, 1c45ca35
Changes: 486

Showing 20 changed files with 3573 additions and 904 deletions (+3573, -904)
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp  +122  -0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp  +287  -69
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp  +36  -0
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp  +310  -0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp  +111  -0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp  +383  -0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp  +236  -106
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp  +2  -0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp  +45  -20
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp  +302  -66
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp  +118  -38
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp  +357  -312
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp  +10  -6
include/ck_tile/ops/gemm/warp/warp_gemm.hpp  +91  -55
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp  +439  -89
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp  +622  -99
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp  +33  -29
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp  +57  -11
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp  +1  -1
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp  +11  -3
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"

namespace ck_tile {

struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
    CK_TILE_HOST BatchedGemmHostArgs() = default;
    CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     void* c_ptr_,
                                     ck_tile::index_t k_batch_,
                                     ck_tile::index_t M_,
                                     ck_tile::index_t N_,
                                     ck_tile::index_t K_,
                                     ck_tile::index_t stride_A_,
                                     ck_tile::index_t stride_B_,
                                     ck_tile::index_t stride_C_,
                                     ck_tile::index_t batch_stride_A_,
                                     ck_tile::index_t batch_stride_B_,
                                     ck_tile::index_t batch_stride_C_,
                                     ck_tile::index_t batch_count_)
        : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
          batch_stride_A(batch_stride_A_),
          batch_stride_B(batch_stride_B_),
          batch_stride_C(batch_stride_C_),
          batch_count(batch_count_)
    {
    }

    ck_tile::index_t batch_stride_A;
    ck_tile::index_t batch_stride_B;
    ck_tile::index_t batch_stride_C;
    ck_tile::index_t batch_count;
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
    using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using GemmKernelArgs = typename Base::GemmKernelArgs;

    using ADataType = typename Base::ADataType;
    using BDataType = typename Base::BDataType;
    using CDataType = typename Base::CDataType;

    using TilePartitioner  = typename Base::TilePartitioner;
    using GemmPipeline     = typename Base::GemmPipeline;
    using EpiloguePipeline = typename Base::EpiloguePipeline;
    using ALayout          = typename Base::ALayout;
    using BLayout          = typename Base::BLayout;
    using CLayout          = typename Base::CLayout;

    struct BatchedGemmKernelArgs : GemmKernelArgs
    {
        index_t batch_stride_A;
        index_t batch_stride_B;
        index_t batch_stride_C;
        index_t batch_count;
    };

    using KernelArgs = BatchedGemmKernelArgs;

    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
    {
        return TilePartitioner::GridSize(M, N, batch_count);
    }

    __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }

    CK_TILE_HOST static constexpr BatchedGemmKernelArgs
    MakeKernelArgs(const BatchedGemmHostArgs& hostArgs)
    {
        return BatchedGemmKernelArgs{{hostArgs.a_ptr,
                                      hostArgs.b_ptr,
                                      hostArgs.c_ptr,
                                      hostArgs.M,
                                      hostArgs.N,
                                      hostArgs.K,
                                      hostArgs.stride_A,
                                      hostArgs.stride_B,
                                      hostArgs.stride_C},
                                     hostArgs.batch_stride_A,
                                     hostArgs.batch_stride_B,
                                     hostArgs.batch_stride_C,
                                     hostArgs.batch_count};
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
    {
        const auto [i_m, i_n] = TilePartitioner{}();
        const auto i_batch    = __builtin_amdgcn_readfirstlane(blockIdx.z);

        // options
        const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
        const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
        const ADataType* a_ptr    = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;

        const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
        const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
        const BDataType* b_ptr    = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;

        const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
        const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
        CDataType* c_ptr          = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;

        this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
    }
};

} // namespace ck_tile
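The batched kernel's per-batch addressing reduces to a strided offset: each blockIdx.z handles one batch and shifts the A/B/C base pointers by i_batch * batch_stride. Below is a minimal host-side sketch of the same arithmetic in plain C++, with hypothetical names and no ck_tile types, just to illustrate the offset rule; it is not the library API.

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Hypothetical stand-in for ck_tile::index_t.
    using index_t = std::int32_t;

    // Mirrors BatchedGemmKernel::operator(): the pointer seen by batch i_batch is
    // the base pointer advanced by i_batch * batch_stride elements.
    template <typename T>
    const T* batch_base(const T* base, index_t i_batch, index_t batch_stride)
    {
        return base + static_cast<std::ptrdiff_t>(i_batch) * batch_stride;
    }

    int main()
    {
        float A[2 * 6] = {}; // two batches of a 2x3 row-major A (batch_stride_A = 6)
        assert(batch_base(A, 1, 6) == A + 6);
        return 0;
    }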
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp

...
@@ -12,6 +12,50 @@
namespace ck_tile {

struct GemmProblem
{
    CK_TILE_HOST GemmProblem() = default;
    CK_TILE_HOST GemmProblem(index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
    {
    }

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};

struct GemmHostArgs : public GemmProblem
{
    CK_TILE_HOST GemmHostArgs() = default;
    CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
                              const void* b_ptr_,
                              void* c_ptr_,
                              index_t k_batch_,
                              index_t M_,
                              index_t N_,
                              index_t K_,
                              index_t stride_A_,
                              index_t stride_B_,
                              index_t stride_C_)
        : GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
          a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    void* c_ptr;
    index_t k_batch;
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
...
@@ -25,9 +69,12 @@ struct GemmKernel
    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();

    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return TilePartitioner::GridSize(M, N, KBatch);
...
@@ -35,7 +82,7 @@ struct GemmKernel
    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

    struct GemmCommonKargs
    struct GemmKernelArgs
    {
        const void* a_ptr;
        const void* b_ptr;
...
@@ -48,36 +95,119 @@ struct GemmKernel
        index_t stride_C;
    };

    CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
                                                            const void* b_ptr,
                                                            void* c_ptr,
                                                            index_t M,
                                                            index_t N,
                                                            index_t K,
                                                            index_t stride_A,
                                                            index_t stride_B,
                                                            index_t stride_C)
    CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
    {
        return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
        return GemmKernelArgs{hostArgs.a_ptr,
                              hostArgs.b_ptr,
                              hostArgs.c_ptr,
                              hostArgs.M,
                              hostArgs.N,
                              hostArgs.K,
                              hostArgs.stride_A,
                              hostArgs.stride_B,
                              hostArgs.stride_C};
    }

    // CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
    //                                                             const void* b_ptr,
    //                                                             void* c_ptr,
    //                                                             index_t M,
    //                                                             index_t N,
    //                                                             index_t K,
    //                                                             index_t stride_A,
    //                                                             index_t stride_B,
    //                                                             index_t stride_C)
    // {
    //     return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
    // }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
    CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
    {
        const auto [i_m, i_n] = TilePartitioner{}();
        // options
        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
        // Convert pointers to tensor views
        auto a_tensor_view = [&]() {
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
            {
                return false;
            }
            if(kargs.K % GemmPipeline::VectorSizeA != 0)
            {
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
            {
                return false;
            }
            if(kargs.M % GemmPipeline::VectorSizeA != 0)
            {
                return false;
            }
        }
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
            {
                return false;
            }
            if(kargs.N % GemmPipeline::VectorSizeB != 0)
            {
                return false;
            }
        }
        else
        {
            if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
            {
                return false;
            }
            if(kargs.K % GemmPipeline::VectorSizeB != 0)
            {
                return false;
            }
        }
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
            {
                return false;
            }
            if(kargs.N % GemmPipeline::VectorSizeC != 0)
            {
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
            {
                return false;
            }
            if(kargs.M % GemmPipeline::VectorSizeC != 0)
            {
                return false;
            }
        }
        return true;
    }

    CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
                                            const BDataType* b_ptr,
                                            CDataType* c_ptr,
                                            const GemmKernelArgs& kargs) const
    {
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    a_ptr,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::VectorSizeA>{},
...
@@ -86,7 +216,7 @@ struct GemmKernel
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    a_ptr,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(1, kargs.stride_A),
                    number<1>{},
...
@@ -94,11 +224,11 @@ struct GemmKernel
            }
        }();

        auto b_tensor_view = [&]() {
        const auto& b_tensor_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    b_ptr,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(1, kargs.stride_B),
                    number<1>{},
...
@@ -107,7 +237,7 @@ struct GemmKernel
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    b_ptr,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(kargs.stride_B, 1),
                    number<GemmPipeline::VectorSizeB>{},
...
@@ -115,46 +245,11 @@ struct GemmKernel
            }
        }();

        auto a_pad_view = pad_tensor_view(
            a_tensor_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
            // somehow clang-format is splitting below line into multiple.
            // clang-format off
            sequence<false, GemmPipeline::kPadA>{});
        // clang-format on
        auto a_block_window = make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
            {i_m, 0});

        auto b_pad_view = pad_tensor_view(
            b_tensor_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
            // clang-format off
            sequence<false, GemmPipeline::kPadB>{});
        // clang-format on
        auto b_block_window = make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
            {i_n, 0});

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);

        // Run GEMM cooperatively by whole wokrgroup.
        auto c_block_tile =
            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);

        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
        auto c_tensor_view = [&]() {
        const auto& c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<GemmPipeline::VectorSizeC>{},
...
@@ -163,7 +258,7 @@ struct GemmKernel
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<1>{},
...
@@ -171,19 +266,142 @@ struct GemmKernel
            }
        }();

        auto c_pad_view = pad_tensor_view(
            c_tensor_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
            // clang-format off
            sequence<false, GemmPipeline::kPadC>{});
        // clang-format on
        auto c_block_window = make_tile_window(
        return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
    }

    template <typename TensorView>
    CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(
                    a_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
                    sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(
                    a_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
                    sequence<GemmPipeline::kPadM, false>{});
            }
        }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I1);
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return pad_tensor_view(
                    b_tensor_view,
                    make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
                    sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(
                    b_tensor_view,
                    make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
                    sequence<GemmPipeline::kPadN, false>{});
            }
        }();

        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I2);
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(
                    c_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
                    sequence<false, GemmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(
                    c_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
                    sequence<GemmPipeline::kPadM, false>{});
            }
        }();

        return make_tuple(a_pad_view, b_pad_view, c_pad_view);
    }

    template <typename PadView>
    CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const
    {
        const auto& a_pad_view = views.at(I0);
        const auto& a_block_window = make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
            {i_m, 0});

        const auto& b_pad_view = views.at(I1);
        const auto& b_block_window = make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
            {i_n, 0});

        const auto& c_pad_view = views.at(I2);
        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, b_block_window, c_block_window);
    }

    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
     *
     * @param a_ptr input A pointer
     * @param b_ptr input B pointer
     * @param c_ptr output C pointer
     * @param kargs GEMM kernel arguments
     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
     */
    CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
                                const BDataType* b_ptr,
                                CDataType* c_ptr,
                                const GemmKernelArgs& kargs,
                                const index_t block_idx_m,
                                const index_t block_idx_n) const
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
        const auto& gemm_pad_views          = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows              = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& c_block_tile =
            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);
        EpiloguePipeline{}(c_block_window, c_block_tile);
    }

    CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
    {
        const auto [i_m, i_n] = TilePartitioner{}();
        // options
        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
        CDataType* c_ptr       = static_cast<CDataType*>(kargs.c_ptr);

        RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
    }
};

} // namespace ck_tile
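The new IsSupportedArgument rejects shapes whose contiguous dimension is neither tile-aligned (when padding is disabled) nor divisible by the load vector width. A standalone sketch of the row-major A branch follows; the tile and vector constants are illustrative stand-ins for TilePartitioner::kK and GemmPipeline::VectorSizeA, not values taken from the library.

    #include <cstdint>
    #include <iostream>

    using index_t = std::int32_t;

    // Hypothetical tile/vector parameters.
    constexpr index_t kK          = 32;
    constexpr index_t VectorSizeA = 8;
    constexpr bool    kPadK       = false;

    // Mirrors the row-major A branch of GemmKernel::IsSupportedArgument.
    bool is_supported_K(index_t K)
    {
        if(K % kK != 0 && kPadK == false) { return false; } // no K padding: K must be tile-aligned
        if(K % VectorSizeA != 0) { return false; }          // vectorized A loads need divisibility
        return true;
    }

    int main()
    {
        std::cout << is_supported_K(128) << ' ' << is_supported_K(100) << '\n'; // prints "1 0"
        return 0;
    }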
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp

...
@@ -35,4 +35,40 @@ struct GemmTilePartitioner
        return make_tuple(iM, iN);
    }
};

template <typename BlockGemmShape_>
struct GemmTile1DPartitioner
{
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
    {
        index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
        index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
        return dim3(GridDimX * GridDimY, 1, 1);
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N)
    {
        return integer_divide_ceil(N, NPerBlock);
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
    {
        return integer_divide_ceil(K, KPerBlock);
    }

    CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize)
    {
        index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) / GetNBlock(NBlockSize) * MPerBlock);
        index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) % GetNBlock(NBlockSize) * NPerBlock);
        return make_tuple(iM, iN);
    }
};

} // namespace ck_tile
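GemmTile1DPartitioner flattens the (M, N) tile grid into a 1D grid of GridDimX * GridDimY blocks and recovers each tile origin from blockIdx.x by dividing and taking the remainder with respect to the number of N-blocks. The same index math in a standalone host-side form is sketched below; the block sizes and names are illustrative only.

    #include <cassert>
    #include <utility>

    using index_t = int;

    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;

    constexpr index_t ceil_div(index_t a, index_t b) { return (a + b - 1) / b; }

    // Mirrors GemmTile1DPartitioner::operator(): the 1D block id (minus the group's
    // blockOffset) is split into an M tile row and an N tile column.
    std::pair<index_t, index_t> tile_origin(index_t block_id, index_t block_offset, index_t N)
    {
        const index_t n_blocks = ceil_div(N, NPerBlock);
        const index_t iM = (block_id - block_offset) / n_blocks * MPerBlock;
        const index_t iN = (block_id - block_offset) % n_blocks * NPerBlock;
        return {iM, iN};
    }

    int main()
    {
        // 512x512 problem -> 4x4 tiles; block 5 maps to tile origin (128, 128).
        assert(tile_origin(5, 0, 512) == std::make_pair(128, 128));
        return 0;
    }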
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host.hpp"

namespace ck_tile {

struct GroupedGemmHostArgs
{
    const void* a_ptr;
    const void* b_ptr;
    void* c_ptr;
    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;

    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    struct GemmTransKernelArg
    {
        GroupedGemmHostArgs group_karg;
        ck_tile::index_t block_start;
        ck_tile::index_t block_end;

        GemmTransKernelArg() = default;
        GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end)
            : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
        {
        }
    };

    __host__ static size_t GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
    {
        return gemm_descs.size() * sizeof(GemmTransKernelArg);
    }

    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

    using Hargs = GroupedGemmHostArgs;

    __host__ static constexpr auto GridSize(const std::vector<Hargs>& gemm_descs)
    {
        index_t grid_size = 0;
        for(const auto& it_desc : gemm_descs)
        {
            const auto dim3 = TilePartitioner::GridSize(it_desc.M, it_desc.N);
            grid_size += dim3.x * dim3.y * 1;
        }
        return dim3(grid_size, 1, 1);
    }

    CK_TILE_HOST static auto MakeKargs(const std::vector<Hargs>& gemm_descs)
    {
        std::vector<GemmTransKernelArg> gemm_kernel_args_;
        index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
        index_t grid_size   = 0;
        gemm_kernel_args_.reserve(group_count);

        for(std::size_t i = 0; i < gemm_descs.size(); ++i)
        {
            const index_t M = gemm_descs[i].M;
            const index_t N = gemm_descs[i].N;
            const index_t K = gemm_descs[i].K;

            if(M == 0 || N == 0 || K == 0)
            {
                continue;
            }

            const index_t stride_a = gemm_descs[i].stride_A;
            const index_t stride_b = gemm_descs[i].stride_B;
            const index_t stride_c = gemm_descs[i].stride_C;

            const auto dim3             = TilePartitioner::GridSize(M, N);
            const index_t grid_size_grp = dim3.x * 1 * 1;
            const index_t block_start   = grid_size;
            const index_t block_end     = grid_size + grid_size_grp;

            grid_size += grid_size_grp;

            auto karg = GroupedGemmHostArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
                                            type_convert<const BDataType*>(gemm_descs[i].b_ptr),
                                            type_convert<CDataType*>(gemm_descs[i].c_ptr),
                                            M,
                                            N,
                                            K,
                                            stride_a,
                                            stride_b,
                                            stride_c};

            gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
        }

        return gemm_kernel_args_;
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const
    {
        const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N);
        // options
        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
        // Convert pointers to tensor views
        auto a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::VectorSizeA>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(1, kargs.stride_A),
                    number<1>{},
                    number<1>{});
            }
        }();

        auto b_tensor_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(1, kargs.stride_B),
                    number<1>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(kargs.stride_B, 1),
                    number<GemmPipeline::VectorSizeB>{},
                    number<1>{});
            }
        }();

        auto a_pad_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(
                    a_tensor_view,
                    make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
                    sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(
                    a_tensor_view,
                    make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
                    sequence<GemmPipeline::kPadM, false>{});
            }
        }();
        // clang-format on
        auto a_block_window = make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {i_m, 0});

        auto b_pad_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return pad_tensor_view(
                    b_tensor_view,
                    make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
                    sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(
                    b_tensor_view,
                    make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
                    sequence<GemmPipeline::kPadN, false>{});
            }
        }();

        auto b_block_window = make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {i_n, 0});

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);

        // Run GEMM cooperatively by whole wokrgroup.
        auto c_block_tile =
            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);

        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
        auto c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<GemmPipeline::VectorSizeC>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<1>{},
                    number<1>{});
            }
        }();

        auto c_pad_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(
                    c_tensor_view,
                    make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
                    sequence<false, GemmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(
                    c_tensor_view,
                    make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
                    sequence<GemmPipeline::kPadM, false>{});
            }
        }();

        auto CBlockWindow_pad = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
    }

    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, int group_count) const
    {
        const index_t block_id   = ck_tile::get_block_1d_id();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));

        index_t left     = 0;
        index_t right    = group_count;
        index_t group_id = index_t((left + right) / 2);

        while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                 block_id < gemm_desc_ptr[group_id].block_end)) &&
              left <= right)
        {
            if(block_id < gemm_desc_ptr[group_id].block_start)
            {
                right = group_id;
            }
            else
            {
                left = group_id;
            }
            group_id = index_t((left + right) / 2);
        }

        Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start);
    }
};

} // namespace ck_tile
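Each group in the grouped GEMM owns a contiguous range [block_start, block_end) of the 1D grid, and the device entry point locates its group by bisecting over those ranges. A standalone host-side sketch of that lookup is shown below; the struct and function names are hypothetical, and the real kernel reads GemmTransKernelArg from constant address space rather than a std::vector.

    #include <cassert>
    #include <vector>

    using index_t = int;

    struct BlockRange { index_t block_start; index_t block_end; };

    // Mirrors the while-loop in GroupedGemmKernel::operator(): bisect until the
    // block id falls inside the current group's [block_start, block_end) range.
    index_t find_group(const std::vector<BlockRange>& groups, index_t block_id)
    {
        index_t left = 0, right = static_cast<index_t>(groups.size());
        index_t group_id = (left + right) / 2;
        while(!(block_id >= groups[group_id].block_start && block_id < groups[group_id].block_end) &&
              left <= right)
        {
            if(block_id < groups[group_id].block_start) { right = group_id; }
            else { left = group_id; }
            group_id = (left + right) / 2;
        }
        return group_id;
    }

    int main()
    {
        const std::vector<BlockRange> groups = {{0, 4}, {4, 10}, {10, 16}};
        assert(find_group(groups, 7) == 1);  // block 7 belongs to the second group
        assert(find_group(groups, 12) == 2); // block 12 belongs to the third group
        return 0;
    }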
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename Problem, typename Policy>
struct GemmPipelineAgBgCrImplBase
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    template <typename DstBlockTile, typename SrcTileWindow>
    CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, SrcTileWindow& dram_tile_window) const
    {
        load_tile(dst_block_tile, dram_tile_window);
        move_tile_window(dram_tile_window, {0, KPerBlock});
    }

    template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
    CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
                                     const SrcBlockTile& src_block_tile,
                                     const ElementFunction& element_func) const
    {
        const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
        store_tile(lds_tile_window, block_tile_tmp);
    }

    CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
    {
        // A tile in LDS
        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
        constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        // TODO: LDS alignment should come from Policy!
        constexpr index_t a_lds_block_space_size_aligned =
            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 16;

        // B tile in LDS
        BDataType* p_b_lds = static_cast<BDataType*>(
            static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
    }

    template <typename ADramBlockWindowTmp, typename ALdsTensorView>
    CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                    const ALdsTensorView& a_lds_block_view) const
    {
        // A DRAM tile window for load
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeADramTileDistribution<Problem>());

        // A LDS tile window for store
        auto a_copy_lds_window =
            make_tile_window(a_lds_block_view,
                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                             {0, 0},
                             a_copy_dram_window.get_tile_distribution());

        auto a_lds_gemm_window = make_tile_window(
            a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

        return make_tuple(std::move(a_copy_dram_window), std::move(a_copy_lds_window), std::move(a_lds_gemm_window));
    }

    template <typename BDramBlockWindowTmp, typename BLdsTensorView>
    CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                    const BLdsTensorView& b_lds_block_view) const
    {
        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeBDramTileDistribution<Problem>());

        // B LDS tile window for store
        auto b_copy_lds_window =
            make_tile_window(b_lds_block_view,
                             make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                             {0, 0},
                             b_copy_dram_window.get_tile_distribution());

        auto b_lds_gemm_window = make_tile_window(
            b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

        return make_tuple(std::move(b_copy_dram_window), std::move(b_copy_lds_window), std::move(b_lds_gemm_window));
    }
};

} // namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"

namespace ck_tile {

// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV3
{
    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }
};

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
    using Base             = BaseGemmPipelineAgBgCrCompV3<Problem>;
    using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;

    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;

    static constexpr index_t BlockSize = Problem::kBlockSize;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    // Where is the right place for HasHotLoop and TailNum ???
    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

    using Base::PrefetchStages;

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
    };

    template <>
    struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
    {
        using Base = PipelineImplBase;

        CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
        {
            constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
            constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
            constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});

            constexpr index_t WaveSize = 64;
            constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
            constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

            constexpr index_t A_LDS_Read_Width = KPerXDL;
            constexpr index_t B_LDS_Read_Width = KPerXDL;

            constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
            constexpr index_t B_Buffer_Load_Inst_Num = NPerBlock * KPerBlock / (BlockSize * VectorSizeB);

            constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t A_LDS_Read_Inst_Num = WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Read_Inst_Num = WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t C_MFMA_Inst_Num =
                MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);

            // A/B split schedule
            // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
            constexpr auto num_ds_read_inst_a =
                A_LDS_Read_Width * sizeof(ADataType) == 16 ? A_LDS_Read_Inst_Num : A_LDS_Read_Inst_Num / 2;
            constexpr auto num_ds_read_inst_b =
                B_LDS_Read_Width * sizeof(BDataType) == 16 ? B_LDS_Read_Inst_Num : B_LDS_Read_Inst_Num / 2;

            constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
            constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;

            constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
            constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;

            constexpr auto num_mfma_inst = C_MFMA_Inst_Num;

            constexpr auto mfma_cycle            = NPerXDL == 16 ? 16 : 32;
            constexpr auto ds_read_a_issue_cycle = A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
            constexpr auto ds_read_b_issue_cycle = B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
            constexpr auto ds_read_a_mfma_rate =
                (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
            constexpr auto ds_read_b_mfma_rate =
                (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

            constexpr auto num_dsread_a_mfma = (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
            constexpr auto num_dsread_b_mfma = (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;

            // stage 1
            // Separate this part?
            // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
            //                                               sizeof(ComputeDataType) / sizeof(BDataType)
            //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
            //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
            constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
            constexpr auto num_mfma_per_issue =
                num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
            constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
            constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;

            static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
                ignore = i;
                static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                    ignore = idswrite;
                    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                });
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(
                    0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
            });
            static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
                ignore = i;
                static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                    ignore = idswrite;
                    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                });
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(
                    0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
            });

            // stage 2
            static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
                if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= ds_read_a_mfma_rate)
                {
                    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
                }
                else
                {
                    __builtin_amdgcn_sched_group_barrier(
                        0x100, num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, 0); // DS read
                }
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
                if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= ds_read_b_mfma_rate)
                {
                    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
                }
                else
                {
                    __builtin_amdgcn_sched_group_barrier(
                        0x100, num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, 0); // DS read
                }
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
        }

        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* p_smem) const
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "A/B Dram block window should have the same data type as appropriate "
                "([A|B]DataType) defined in Problem definition!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
                          "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                          " or KPerBlock!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles

            // A/B tiles in LDS
            auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);

            // A DRAM tile window for load
            // A LDS tile window for store
            // A LDS tile for block GEMM
            auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
                Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);

            // B DRAM tile window for load
            // B LDS tile window for store
            // B LDS tile for block GEMM
            auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
                Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);

            // Block GEMM
            auto block_gemm   = BlockGemm();
            auto c_block_tile = block_gemm.MakeCBlockTile();

            using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
            using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

            using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
            using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

            ABlockTile a_block_tile;
            BBlockTile b_block_tile;

            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            // prefetch
            // global read 0
            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
            Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);

            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);

            block_sync_lds();
            block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
            __builtin_amdgcn_sched_barrier(0);

            // main body
            if constexpr(HasHotLoop)
            {
                index_t i = 0;
                do
                {
                    block_sync_lds();

                    Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
                    Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);

                    Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
                    Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);

                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);

                    block_sync_lds();
                    block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);

                    HotLoopScheduler();
                    __builtin_amdgcn_sched_barrier(0);

                    i += 1;
                } while(i < (num_loop - 1));
            }

            // tail
            if constexpr(TailNum == TailNumber::Full)
            {
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            }

            // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
            // latency
            // __builtin_amdgcn_sched_barrier(0);

            return c_block_tile;
        }
    };

    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const AElementFunction& a_element_func,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const BElementFunction& b_element_func,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp, a_element_func, b_dram_block_window_tmp, b_element_func, num_loop, p_smem);
    }

    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem);
    }
};

} // namespace ck_tile
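The Intrawave HotLoopScheduler paces LDS reads against MFMAs: the ds_read_a_mfma_rate expression ceil-divides (mfma_cycle - 4) by two ds_read issue cycles to decide how many LDS reads are grouped per MFMA slot. A small standalone sketch of just that arithmetic is below; the cycle values are illustrative, chosen to match the 16-byte ds_read and 32-cycle MFMA case named in the source.

    #include <iostream>

    // Mirrors the ds_read_*_mfma_rate expression in HotLoopScheduler.
    constexpr int ds_read_mfma_rate(int mfma_cycle, int ds_read_issue_cycle)
    {
        return (mfma_cycle - 4 + 2 * ds_read_issue_cycle - 1) / (2 * ds_read_issue_cycle);
    }

    int main()
    {
        // 32-cycle MFMA (NPerXDL != 16), 8-cycle 16-byte ds_read issue -> 2 reads per MFMA slot.
        std::cout << ds_read_mfma_rate(32, 8) << '\n'; // prints 2
        // 16-cycle MFMA (NPerXDL == 16), 8-cycle ds_read issue -> 1 read per MFMA slot.
        std::cout << ds_read_mfma_rate(16, 8) << '\n'; // prints 1
        return 0;
    }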
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
b74918bc
...
...
@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace
ck_tile
{
...
...
@@ -90,7 +91,8 @@ struct BaseGemmPipelineAgBgCrMem
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
GemmPipelineAgBgCrMem
:
public
BaseGemmPipelineAgBgCrMem
<
Problem
>
{
using
Base
=
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
using
Base
=
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
using
PipelineImplBase
=
GemmPipelineAgBgCrImplBase
<
Problem
,
Policy
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
...
@@ -103,8 +105,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Policy
::
template
GetBlockGemm
<
Problem
>())
>
;
using
I0
=
number
<
0
>
;
using
I1
=
number
<
1
>
;
using
I2
=
number
<
2
>
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
...
...
@@ -113,9 +116,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
bool
kPad
A
=
Problem
::
kPad
A
;
static
constexpr
bool
kPad
B
=
Problem
::
kPad
B
;
static
constexpr
bool
kPad
C
=
Problem
::
kPad
C
;
static
constexpr
bool
kPad
M
=
Problem
::
kPad
M
;
static
constexpr
bool
kPad
N
=
Problem
::
kPad
N
;
static
constexpr
bool
kPad
K
=
Problem
::
kPad
K
;
// Where is the right place for HasHotLoop and TailNum ???
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
...
...
@@ -124,46 +127,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using
Base
::
PrefetchStages
;
CK_TILE_HOST_DEVICE
constexpr
index_t
GetStaticLdsSize
()
{
return
integer_divide_ceil
(
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
16
+
sizeof
(
BDataType
)
*
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
struct
PipelineImpl
:
public
PipelineImplBase
{
};
template
<
>
struct
PipelineImpl
<
GemmPipelineScheduler
::
Intrawave
>
struct
PipelineImpl
<
GemmPipelineScheduler
::
Intrawave
>
:
public
PipelineImplBase
{
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
>
CK_TILE_DEVICE
void
GlobalPrefetch
(
DstBlockTile
&
dst_block_tile
,
SrcTileWindow
&
dram_tile_window
)
const
{
load_tile
(
dst_block_tile
,
dram_tile_window
);
move_tile_window
(
dram_tile_window
,
{
0
,
KPerBlock
});
}
template
<
typename
DstTileWindow
,
typename
SrcBlockTile
,
typename
ElementFunction
>
CK_TILE_DEVICE
void
LocalPrefill
(
DstTileWindow
&
lds_tile_window
,
const
SrcBlockTile
&
src_block_tile
,
const
ElementFunction
&
element_func
)
const
{
const
auto
block_tile_tmp
=
tile_elementwise_in
(
element_func
,
src_block_tile
);
store_tile
(
lds_tile_window
,
block_tile_tmp
);
}
using
Base
=
PipelineImplBase
;
template
<
bool
HasHotLoop
,
TailNumber
TailNum
,
...
...
@@ -185,70 +162,42 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A tile in LDS
ADataType
*
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
// TODO: LDS alignment should come from Policy!
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
16
;
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
));
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
auto
b_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
,
b_lds_block_desc
);
// A/B tiles in LDS
// With c++20 could simplify to below line.
// Currently get error: captured structured bindings are a C++20 extension
// auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
auto
ab_lds_blocks
=
Base
::
GetABLdsTensorViews
(
p_smem
);
auto
&
a_lds_block
=
ab_lds_blocks
.
at
(
I0
{});
auto
&
b_lds_block
=
ab_lds_blocks
.
at
(
I1
{});
// A DRAM tile window for load
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
auto
a_copy_lds_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
// B DRAM tile window for load
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// A LDS tile for block GEMM
auto
a_windows
=
Base
::
GetAWindows
(
a_dram_block_window_tmp
,
a_lds_block
);
auto
&
a_copy_dram_window
=
a_windows
.
at
(
I0
{});
auto
&
a_copy_lds_window
=
a_windows
.
at
(
I1
{});
auto
&
a_lds_gemm_window
=
a_windows
.
at
(
I2
{});
// B DRAM tile window for load
// B LDS tile window for store
auto
b_copy_lds_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
// A LDS tile for block GEMM
auto
a_lds_gemm_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
// B LDS tile for block GEMM
auto
b_lds_gemm_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
auto
b_windows
=
Base
::
GetBWindows
(
b_dram_block_window_tmp
,
b_lds_block
);
auto
&
b_copy_dram_window
=
b_windows
.
at
(
I0
{});
auto
&
b_copy_lds_window
=
b_windows
.
at
(
I1
{});
auto
&
b_lds_gemm_window
=
b_windows
.
at
(
I2
{});
// Block GEMM
constexpr
auto
block_gemm
=
BlockGemm
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
auto
block_gemm
=
BlockGemm
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
...
...
@@ -266,20 +215,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
        // prefetch
        // global read 0
        GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
        GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);

        // initialize C
        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

        // LDS write 0
        LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
        LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
        Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
        Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);

        // Global prefetch [1, PrefetchStages]
        static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
            GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
            GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
        });
// main body
...
...
@@ -290,24 +239,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
        {
            static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                block_sync_lds();
                // block_gemm.LocalPrefetch();
                block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                block_sync_lds();

                LocalPrefill(
                Base::LocalPrefill(
                    a_copy_lds_window,
                    a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                    a_element_func);
                LocalPrefill(
                Base::LocalPrefill(
                    b_copy_lds_window,
                    b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                    b_element_func);

                GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
                Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            });
            i += PrefetchStages;
...
...
@@ -318,27 +267,208 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
            static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
                block_sync_lds();
                // block_gemm.LocalPrefetch();
                block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                block_sync_lds();

                LocalPrefill(
                    a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
                LocalPrefill(
                    b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
                Base::LocalPrefill(
                    a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
                Base::LocalPrefill(
                    b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
            });
            block_sync_lds();
            block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
        };
        if constexpr(TailNum == TailNumber::One)
        {
            block_sync_lds();
            block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
        }
        else if constexpr(TailNum == TailNumber::Two)
        {
            HotLoopTail(number<2>{});
        }
        else if constexpr(TailNum == TailNumber::Three)
        {
            HotLoopTail(number<3>{});
        }
        else if constexpr(TailNum == TailNumber::Four)
        {
            HotLoopTail(number<4>{});
        }
        else if constexpr(TailNum == TailNumber::Five)
        {
            HotLoopTail(number<5>{});
        }
        else if constexpr(TailNum == TailNumber::Six)
        {
            HotLoopTail(number<6>{});
        }
        else if constexpr(TailNum == TailNumber::Seven)
        {
            HotLoopTail(number<7>{});
        }
        else if constexpr(TailNum == TailNumber::Full)
        {
            HotLoopTail(number<PrefetchStages>{});
        }

        return c_block_tile;
    }
};
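// Illustrative sketch (not part of this header): the hot loop above keeps PrefetchStages tiles in
// flight, computing on one buffer slot while the LDS slot for stage (i + 1) % PrefetchStages is
// refilled. A minimal, self-contained host-side model of that slot rotation follows;
// PrefetchStages = 2 is an assumed value purely for illustration.
#include <cstdio>

int main()
{
    constexpr int PrefetchStages = 2; // assumed; the real pipeline derives this from the policy
    for(int i = 0; i < 6; ++i)
    {
        const int compute_slot = i % PrefetchStages;       // slot consumed by block_gemm
        const int refill_slot  = (i + 1) % PrefetchStages; // slot refilled by LocalPrefill
        std::printf("iteration %d: compute slot %d, refill slot %d\n", i, compute_slot, refill_slot);
    }
    return 0;
}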
template <>
struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
{
    using Base = PipelineImplBase;

    template <bool HasHotLoop,
              TailNumber TailNum,
              typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const AElementFunction& a_element_func,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const BElementFunction& b_element_func,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        static_assert(
            std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
            "A/B Dram block window should have the same data type as appropriate "
            "([A|B]DataType) defined in Problem definition!");

        static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                          NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                          KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
                      "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                      " or KPerBlock!");
        // ------------------------------------------------------------------------------------
        // Definitions of all needed tiles

        // A/B tiles in LDS
        // With c++20 could simplify to below line.
        // Currently get error: captured structured bindings are a C++20 extension
        // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
        auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
        auto& a_lds_block  = ab_lds_blocks.at(I0{});
        auto& b_lds_block  = ab_lds_blocks.at(I1{});

        // A DRAM tile window for load
        // A LDS tile window for store
        // A LDS tile for block GEMM
        auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
        auto& a_copy_dram_window = a_windows.at(I0{});
        auto& a_copy_lds_window  = a_windows.at(I1{});
        auto& a_lds_gemm_window  = a_windows.at(I2{});

        // B DRAM tile window for load
        // B LDS tile window for store
        // B LDS tile for block GEMM
        auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
        auto& b_copy_dram_window = b_windows.at(I0{});
        auto& b_copy_lds_window  = b_windows.at(I1{});
        auto& b_lds_gemm_window  = b_windows.at(I2{});

        // Block GEMM
        auto block_gemm   = BlockGemm();
        auto c_block_tile = block_gemm.MakeCBlockTile();

        using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
        using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

        using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
        using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

        tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
        tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
        // -----------------------------------------------------------------------------------------
        // Gemm pipeline start

        // prefetch
        // global read 0
        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);

        // initialize C
        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

        // LDS write 0
        Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
        Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);

        // Global prefetch [1, PrefetchStages]
        static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
        });
        // main body
        if constexpr(HasHotLoop)
        {
            index_t i = 0;
            do
            {
                static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                    block_sync_lds();
                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                    // no second block_sync_lds because it's interwave
                    Base::LocalPrefill(
                        a_copy_lds_window,
                        a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                        a_element_func);
                    Base::LocalPrefill(
                        b_copy_lds_window,
                        b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                        b_element_func);
                    Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                                         a_copy_dram_window);
                    Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                                         b_copy_dram_window);
                });
                i += PrefetchStages;
            } while(i < (num_loop - PrefetchStages));
        }
        auto HotLoopTail = [&](auto tail_num) {
            static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
                block_sync_lds();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                // no second block_sync_lds because it's interwave
                Base::LocalPrefill(
                    a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
                Base::LocalPrefill(
                    b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
            });
            block_sync_lds();
            // block_gemm.LocalPrefetch();
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
        };

        if constexpr(TailNum == TailNumber::One)
        {
            block_sync_lds();
            // block_gemm.LocalPrefetch();
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
        }
        else if constexpr(TailNum == TailNumber::Two)
...
...
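// Illustrative sketch (not from this commit): the hot loop consumes K-iterations in chunks of
// PrefetchStages, and whatever remains (1..PrefetchStages iterations) is finished by HotLoopTail /
// the TailNumber dispatch above. A self-contained model of that split, with an assumed stage count:
#include <cassert>

constexpr int compute_tail(int num_loop, int prefetch_stages)
{
    // multiples of prefetch_stages go through the hot loop; the remainder is the tail length
    const int tail = num_loop % prefetch_stages;
    return tail == 0 ? prefetch_stages : tail;
}

static_assert(compute_tail(8, 2) == 2, "a full tail of 2 stages");
static_assert(compute_tail(7, 2) == 1, "a tail of a single stage");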
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
...
...
@@ -11,6 +11,7 @@ namespace ck_tile {
enum struct GemmPipelineScheduler
{
    Default,
    Intrawave,
    Interwave,
};
...
...
@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
{
    switch(s)
    {
    case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
    case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
    case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
    default: os << "";
...
...
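// Illustrative sketch (not part of this header): the enum is typically used as a non-type template
// parameter to select a pipeline implementation, mirroring the
// PipelineImpl<GemmPipelineScheduler::Intrawave/Interwave> specializations in the memory pipeline.
// The trait name below is hypothetical and shown only to demonstrate the dispatch pattern.
template <ck_tile::GemmPipelineScheduler S>
struct scheduler_name;

template <>
struct scheduler_name<ck_tile::GemmPipelineScheduler::Intrawave>
{
    static constexpr const char* value = "Intrawave";
};

template <>
struct scheduler_name<ck_tile::GemmPipelineScheduler::Interwave>
{
    static constexpr const char* value = "Interwave";
};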
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...
...
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;

    static constexpr bool kPadA = Problem::kPadA;
    static constexpr bool kPadB = Problem::kPadB;
    static constexpr bool kPadC = Problem::kPadC;
    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
    {
...
...
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
                             Policy::template MakeADramTileDistribution<Problem>());

        // A LDS tile window for store
        auto a_copy_lds_window = make_tile_window(a_lds_block,
                                                  make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                                                  {0, 0},
                                                  a_copy_dram_window.get_tile_distribution());
        auto a_copy_lds_window = make_tile_window(
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // B DRAM tile window for load
        auto b_copy_dram_window =
...
...
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
                             Policy::template MakeBDramTileDistribution<Problem>());

        // B LDS tile window for store
        auto b_copy_lds_window = make_tile_window(b_lds_block,
                                                  make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                                                  {0, 0},
                                                  b_copy_dram_window.get_tile_distribution());
        auto b_copy_lds_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // A LDS tile for block GEMM
        auto a_lds_gemm_window = make_tile_window(
...
...
@@ -130,7 +124,7 @@ struct GemmPipelineAGmemBGmemCRegV1
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Block GEMM
        constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
        auto block_gemm            = Policy::template GetBlockGemm<Problem>();

        // Acc register tile
        auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
...
...
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

        // LDS write 0
        const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
        store_tile(a_copy_lds_window, a_block_tile_tmp);
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
        {
            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                Policy::template MakeShuffledARegBlockDescriptor<Problem>());
            shuffle_tile(a_shuffle_tmp, a_block_tile);
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
            store_tile(a_copy_lds_window, a_block_tile_tmp);
        }
        else
        {
            store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
        }

        // LDS write 0
        const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
        store_tile(b_copy_lds_window, b_block_tile_tmp);
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
            shuffle_tile(b_shuffle_tmp, b_block_tile);
            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
            store_tile(b_copy_lds_window, b_block_tile_tmp);
        }
        else
        {
            store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
        }
    }

    index_t iCounter = num_loop - 1;
...
...
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
        store_tile(a_copy_lds_window, a_block_tile_tmp);

        // LDS write i + 1
        const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
        store_tile(b_copy_lds_window, b_block_tile_tmp);
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
                Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
            shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
            store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
        }
        else
        {
            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
            store_tile(b_copy_lds_window, b_block_tile_tmp);
        }

        iCounter--;
    }
...
...
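// Illustrative sketch (separate from the pipeline above): the ColumnMajor-A / RowMajor-B branches
// shuffle the register tile into the K-contiguous order expected by LDS before store_tile. A
// minimal host-side analogue of that reorder follows; the tile sizes and function name are made up
// for illustration only.
#include <array>

template <int M, int K>
std::array<float, M * K> shuffle_to_k_major(const std::array<float, M * K>& m_stride_major)
{
    std::array<float, M * K> k_major{};
    for(int m = 0; m < M; ++m)
        for(int k = 0; k < K; ++k)
            k_major[m * K + k] = m_stride_major[k * M + m]; // transpose: K becomes the fast axis
    return k_major;
}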
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...
...
@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
...
...
@@ -11,6 +12,7 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
#if 0
// 2d
template <typename Problem>
...
...
@@ -51,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        // TODO: this 8 is AK1! should be a policy parameter!
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
...
...
@@ -116,6 +119,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
        return smem_size;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        return Problem::VectorLoadSize / sizeof(ADataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
        return Problem::VectorLoadSize / sizeof(BDataType);
    }
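// Worked example (values assumed for illustration): with VectorLoadSize == 16 bytes, GetSmemPackA/B
// return the number of elements packed per LDS vector:
//   4-byte elements (fp32) -> 16 / 4 = 4
//   2-byte elements (fp16) -> 16 / 2 = 8
//   1-byte elements (int8) -> 16 / 1 = 16
static_assert(16 / sizeof(float) == 4, "pack size for 4-byte elements under a 16-byte vector load");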
#elif 1
// fake XOR
template
<
typename
Problem
>
...
...
@@ -192,88 +209,307 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
kMPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
#endif
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
KPack
=
GetSmemPackA
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
))
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
KPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
static_assert
(
N0
*
N1
*
N2
==
NPerBlock
,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
static_assert
(
N0
*
N1
*
N2
==
NPerBlock
,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!"
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockDescriptor
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
kNPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
#endif
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockDescriptor
()
{
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
kMPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
kMPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
kKPack
=
GetSmemPackA
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
        return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        constexpr bool TransposeC = false;

        constexpr auto I0 = number<0>{};
        constexpr auto I1 = number<1>{};
        constexpr auto I2 = number<2>{};

        using AccDataType = float;
        using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
        using WarpTile    = typename Problem::BlockGemmShape::WarpTile;

        using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                                typename Problem::BDataType,
                                                AccDataType,
                                                WarpTile::at(I0),
                                                WarpTile::at(I1),
                                                WarpTile::at(I2),
                                                TransposeC>;

        using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                                                      typename Problem::BDataType,
                                                                      typename Problem::CDataType,
                                                                      BlockWarps,
                                                                      WarpGemm>;

        return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
    }
};
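// Illustrative sketch (not part of this policy): GetBlockGemm above maps the problem's warp tile
// (M, N, K per warp) onto a matching MFMA instruction via WarpGemmMfmaDispatcher and wires it into
// the block-level GEMM. A hypothetical instantiation, with the shape and traits types assumed, might
// read like the commented lines below; it is shown only to indicate how the pieces connect.
//
// using Problem = ck_tile::GemmPipelineProblem<ck_tile::half_t, ck_tile::half_t, float,
//                                              SomeBlockGemmShape, SomeGemmTraits>;
// constexpr auto block_gemm =
//     ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy::GetBlockGemm<Problem>();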
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
...
...
@@ -3,40 +3,135 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {

static constexpr int _VectorSize = 16;
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          typename TileGemmTraits_>
struct GemmPipelineProblem
struct GemmPipelineProblemBase
{
    using ADataType  = remove_cvref_t<ADataType_>;
    using BDataType  = remove_cvref_t<BDataType_>;
    using CDataType  = remove_cvref_t<CDataType_>;
    using GemmTraits = remove_cvref_t<TileGemmTraits_>;

    using ADataType      = remove_cvref_t<ADataType_>;
    using BDataType      = remove_cvref_t<BDataType_>;
    using CDataType      = remove_cvref_t<CDataType_>;
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
    using GemmTraits     = remove_cvref_t<TileGemmTraits_>;

    using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
    using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
    using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;

    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

    static constexpr bool kPadA = GemmTraits::kPadA;
    static constexpr bool kPadB = GemmTraits::kPadB;
    static constexpr bool kPadC = GemmTraits::kPadC;

    static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
    static constexpr index_t kBlockSize     = BlockGemmShape::NumWarps * get_warp_size();

    static constexpr bool kPadM = GemmTraits::kPadM;
    static constexpr bool kPadN = GemmTraits::kPadN;
    static constexpr bool kPadK = GemmTraits::kPadK;

    static constexpr auto Scheduler = GemmPipelineScheduler::Default;
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
    {
        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t pixels_per_thread =
                BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
            return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
                       ? pixels_per_thread
                       : VectorLoadSize / sizeof(ADataType);
        }
        else
        {
            return VectorLoadSize / sizeof(ADataType);
        }
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
    {
        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        {
            constexpr index_t pixels_per_thread =
                BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
            return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
                       ? pixels_per_thread
                       : VectorLoadSize / sizeof(BDataType);
        }
        else
        {
            return VectorLoadSize / sizeof(BDataType);
        }
    }
    static constexpr index_t VectorSizeA = kPadA ? 1 : _VectorSize / sizeof(ADataType);
    static constexpr index_t VectorSizeB = kPadB ? 1 : _VectorSize / sizeof(BDataType);
    static constexpr index_t VectorSizeC = kPadC ? 1 : _VectorSize / sizeof(CDataType);
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
    {
        if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t N1 = kBlockSize / get_warp_size();
            constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
            constexpr index_t M0 = get_warp_size() / N2;
            constexpr index_t M1 = BlockGemmShape::kM / M0;
            return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
        }
        else
        {
            constexpr index_t M1 = kBlockSize / get_warp_size();
            constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
            constexpr index_t N0 = get_warp_size() / M2;
            constexpr index_t N1 = BlockGemmShape::kN / N0;
            return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
        }
    }
    static constexpr index_t VectorSizeA = []() {
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            return kPadK ? 1 : GetAlignmentA();
        }
        else
        {
            return kPadM ? 1 : GetAlignmentA();
        }
    }();

    static constexpr index_t VectorSizeB = []() {
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
        {
            return kPadN ? 1 : GetAlignmentB();
        }
        else
        {
            return kPadK ? 1 : GetAlignmentB();
        }
    }();

    static constexpr index_t VectorSizeC = []() {
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            return kPadN ? 1 : GetAlignmentC();
        }
        else
        {
            return kPadM ? 1 : GetAlignmentC();
        }
    }();
};
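// Worked example (numbers assumed for illustration): for fp16 A in RowMajor with kPadK == false,
// GetAlignmentA() = VectorLoadSize / sizeof(ADataType) = 16 / 2 = 8, so VectorSizeA == 8; with
// kPadK == true the K extent may be ragged and VectorSizeA falls back to 1, as the lambda above
// encodes.
static_assert(16 / 2 == 8, "fp16 elements per 16-byte vector load");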
// Alias for GemmPipelineProblem
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          typename TileGemmTraits_>
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
                                                    BDataType_,
                                                    CDataType_,
                                                    BlockGemmShape_,
                                                    TileGemmTraits_>;
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
...
...
@@ -45,30 +140,15 @@ template <typename ADataType_,
          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
          bool HasHotLoop_                 = true,
          TailNumber TailNum_              = TailNumber::Full>
struct UniversalGemmPipelineProblem
struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_,
                                                                     BDataType_,
                                                                     CDataType_,
                                                                     BlockGemmShape_,
                                                                     TileGemmTraits_>
{
    using ADataType      = remove_cvref_t<ADataType_>;
    using BDataType      = remove_cvref_t<BDataType_>;
    using CDataType      = remove_cvref_t<CDataType_>;
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
    using GemmTraits     = remove_cvref_t<TileGemmTraits_>;

    using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
    using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
    using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;

    static constexpr auto Scheduler  = Scheduler_;
    static constexpr auto HasHotLoop = HasHotLoop_;
    static constexpr auto TailNum    = TailNum_;

    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

    static constexpr bool kPadA = GemmTraits::kPadA;
    static constexpr bool kPadB = GemmTraits::kPadB;
    static constexpr bool kPadC = GemmTraits::kPadC;

    static constexpr index_t VectorSizeA = kPadA ? _VectorSize / sizeof(ADataType) : 1;
    static constexpr index_t VectorSizeB = kPadB ? _VectorSize / sizeof(BDataType) : 1;
    static constexpr index_t VectorSizeC = kPadC ? _VectorSize / sizeof(CDataType) : 1;

    static constexpr auto Scheduler  = Scheduler_;
    static constexpr auto HasHotLoop = HasHotLoop_;
    static constexpr auto TailNum    = TailNum_;
};
} // namespace ck_tile
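// Worked example (block shape and data type assumed for illustration): in the ColumnMajor-A branch
// of GetAlignmentA the vector width is additionally capped by how many elements each thread owns.
// With kM = 128, kK = 32, kBlockSize = 256 and fp16, pixels_per_thread = 128 * 32 / 256 = 16 while
// VectorLoadSize / sizeof = 8, so the alignment stays 8; with kK = 8 the per-thread count drops to
// 4 and becomes the limit.
constexpr int pixels_per_thread_example = 128 * 8 / 256;
static_assert(pixels_per_thread_example == 4, "per-thread element count caps the vector width");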
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...
...
@@ -9,12 +9,8 @@
namespace ck_tile {

// UniversalGemm Policy
template <typename LayoutA_, typename LayoutB_, typename LayoutC_>
struct UniversalGemmPipelineAgBgCrPolicy
{
    using LayoutA = remove_cvref_t<LayoutA_>;
    using LayoutB = remove_cvref_t<LayoutB_>;
    using LayoutC = remove_cvref_t<LayoutC_>;

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};
...
...
@@ -22,286 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy
    static constexpr bool TransposeC = true;

    template <typename Problem, typename DataType, index_t MNPerBlock>
    CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
    {
        constexpr index_t BlockSize           = Problem::kBlockSize;
        constexpr index_t KPerBlock           = Problem::BlockGemmShape::kK;
        constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;

        if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
        {
            return (16 / sizeof(DataType));
        }
        else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0)
        {
            return (8 / sizeof(DataType));
        }
        else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 && sizeof(DataType) >= 4)
        {
            return (4 / sizeof(DataType));
        }
        else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 && sizeof(DataType) >= 2)
        {
            return (2 / sizeof(DataType));
        }
        else
        {
            return 1;
        }
    }
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayoutA
>::
value
)
{
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
);
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
*
number
<
MLdsLayer
>
{},
number
<
MPerBlock
/
MLdsLayer
>
{},
K1
),
make_tuple
(
K1
,
number
<
KPerBlock
*
MLdsLayer
>
{},
I1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
K0
*
MLdsLayer
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
MLdsLayer
>
{})),
make_pass_through_transform
(
number
<
MPerBlock
/
MLdsLayer
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
MLdsLayer
>
{}))),
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
a_lds_block_desc_m_k
;
}
else
// ColumnMajor A
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr
auto
M0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I0
);
constexpr
auto
M1
=
MPerBlock
/
M0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
M0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kM
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
M0
*
sizeof
(
ADataType
)
>
128
)
?
1
:
128
/
(
K1
*
M0
*
sizeof
(
ADataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=mpair<=kN0
constexpr
auto
mpair
=
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)))
>
M0
?
M0
:
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)));
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{},
number
<
mpair
>
{},
K1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
a_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
M1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
M0
/
mpair
>
{},
number
<
mpair
>
{},
number
<
M1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
a_lds_block_desc_m_k
;
}
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
constexpr
auto
DataTypeSize
=
sizeof
(
ADataType
);
constexpr
auto
MLdsLayer
=
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
*
MLdsLayer
>
{},
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
MLdsLayer
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc_0
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
KPerBlock
/
KPack
*
MLdsLayer
>
{})),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
a_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MLdsLayer
>
{})),
make_pass_through_transform
(
number
<
MPerBlock
/
MLdsLayer
>
{}),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
a_lds_block_desc
=
transform_tensor_descriptor
(
a_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
MLdsLayer
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
a_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayoutB
>::
value
)
{
// NLdsLayer * K0 as logical Bank
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
);
;
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
*
number
<
NLdsLayer
>
{},
number
<
NPerBlock
/
NLdsLayer
>
{},
K1
),
make_tuple
(
K1
,
number
<
KPerBlock
*
NLdsLayer
>
{},
I1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
K0
*
NLdsLayer
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{}))),
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
b_lds_block_desc_n_k
;
}
else
// RowMajor B
{
constexpr
auto
N0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I1
);
constexpr
auto
N1
=
NPerBlock
/
N0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
N0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kN
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
N0
*
sizeof
(
BDataType
)
>
128
)
?
1
:
128
/
(
K1
*
N0
*
sizeof
(
BDataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=npair<=kN0
constexpr
auto
npair
=
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)))
>
N0
?
N0
:
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)));
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{},
number
<
npair
>
{},
K1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
b_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
N1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
N0
/
npair
>
{},
number
<
npair
>
{},
number
<
N1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
b_lds_block_desc_n_k
;
}
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
>
();
constexpr
auto
DataTypeSize
=
sizeof
(
BDataType
);
constexpr
auto
NLdsLayer
=
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
*
NLdsLayer
>
{},
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
NLdsLayer
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc_0
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
KPerBlock
/
KPack
*
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
b_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
b_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc
;
}
template
<
typename
Problem
>
...
...
@@ -334,69 +180,268 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
    using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                            typename Problem::BDataType,
                                            typename Problem::CDataType,
                                            Problem::BlockGemmShape::WarpTile::at(I0),
                                            Problem::BlockGemmShape::WarpTile::at(I1),
                                            Problem::BlockGemmShape::WarpTile::at(I2),
                                            TransposeC>;
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using ALayout   = remove_cvref_t<typename Problem::ALayout>;

    constexpr index_t BlockSize = Problem::kBlockSize;
    constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
    constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

    constexpr index_t K1 = WarpGemm::kK;
    constexpr index_t K0 = KPerBlock / K1;
    constexpr index_t M2 = get_warp_size() / K0;
    constexpr index_t M1 = BlockSize / get_warp_size();
    static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
    static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
    constexpr index_t M0 = MPerBlock / (M2 * M1);

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<1>, sequence<2, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 1>>{});

    if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
    {
        constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
        constexpr index_t M0 = MPerBlock / M1;
        constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
        static_assert(total_pixels % M1 == 0);
        constexpr index_t K3    = total_pixels / M1;
        constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
        static_assert(KPack % K3 == 0);
        constexpr index_t K2 = KPack / K3;
        if constexpr(get_warp_size() % (K2 * M0) == 0)
        {
            constexpr index_t K1 = get_warp_size() / (K2 * M0);
            constexpr index_t K0 = BlockSize / get_warp_size();
            static_assert(KPerBlock == K0 * K1 * K2 * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
                                           tuple<sequence<2>, sequence<2, 1, 2>>,
                                           tuple<sequence<0>, sequence<1, 0, 2>>,
                                           sequence<2, 1>,
                                           sequence<3, 1>>{});
        }
        else
        {
            constexpr index_t K1   = (K2 * M0) / get_warp_size();
            constexpr index_t K2_m = K2 / K1;
            constexpr index_t K0   = BlockSize / get_warp_size() / K1;
            static_assert(KPerBlock == K0 * K1 * K2_m * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
                                           tuple<sequence<2, 2>, sequence<1, 2>>,
                                           tuple<sequence<0, 1>, sequence<0, 2>>,
                                           sequence<2, 1>,
                                           sequence<3, 1>>{});
        }
    }
    else
    {
        constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
        constexpr index_t K0 = KPerBlock / K1;
        constexpr index_t M2 = get_warp_size() / K0;
        if constexpr(get_warp_size() % (M2 * K0) == 0)
        {
            constexpr index_t M1 = BlockSize / get_warp_size();
            static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
            static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
            constexpr index_t M0 = MPerBlock / (M2 * M1);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<2, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 1>>{});
        }
        else
        {
            constexpr index_t M0 = BlockSize / get_warp_size();
            constexpr index_t M1 = MPerBlock / (M2 * M0);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<0>, sequence<2, 0>>,
                                           sequence<1, 2>,
                                           sequence<1, 1>>{});
        }
    }
}
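Illustrative note (not part of this commit): the RowMajor branch above factors the A tile as MPerBlock = M0 * M1 * M2 and KPerBlock = K0 * K1, with K1 sized by the vector load. A minimal standalone sketch of that arithmetic, using an assumed example configuration (BlockSize = 256, MPerBlock = 128, KPerBlock = 32, 16-byte vector loads of fp16, warp size 64; these numbers are illustrative only):

// Sketch: recomputes the RowMajor A-tile factorization on the host for one assumed config.
#include <cassert>
#include <cstdio>

int main()
{
    const int BlockSize      = 256;
    const int MPerBlock      = 128;
    const int KPerBlock      = 32;
    const int VectorLoadSize = 16; // bytes
    const int elem_bytes     = 2;  // fp16
    const int warp_size      = 64;

    const int K1 = VectorLoadSize / elem_bytes;  // 8 elements per vector load
    const int K0 = KPerBlock / K1;               // 4
    const int M2 = warp_size / K0;               // 16 lanes cover M within one warp
    const int M1 = BlockSize / warp_size;        // 4 warps
    const int M0 = MPerBlock / (M2 * M1);        // 2 repeats per lane

    assert(M0 * M1 * M2 == MPerBlock && K0 * K1 == KPerBlock);
    std::printf("M0=%d M1=%d M2=%d K0=%d K1=%d\n", M0, M1, M2, K0, K1);
    return 0;
}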
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
    using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                            typename Problem::BDataType,
                                            typename Problem::CDataType,
                                            Problem::BlockGemmShape::WarpTile::at(I0),
                                            Problem::BlockGemmShape::WarpTile::at(I1),
                                            Problem::BlockGemmShape::WarpTile::at(I2),
                                            TransposeC>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using BLayout   = remove_cvref_t<typename Problem::BLayout>;

    constexpr index_t BlockSize = Problem::kBlockSize;
    constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
    constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

    if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
    {
        constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
        constexpr index_t N0 = NPerBlock / N1;
        constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
        static_assert(total_pixels % N1 == 0);
        constexpr index_t K3    = total_pixels / N1;
        constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
        static_assert(KPack % K3 == 0);
        constexpr index_t K2 = KPack / K3;
        if constexpr(get_warp_size() % (K2 * N0) == 0)
        {
            constexpr index_t K1 = get_warp_size() / (K2 * N0);
            constexpr index_t K0 = BlockSize / get_warp_size();
            static_assert(KPerBlock == K0 * K1 * K2 * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                           tuple<sequence<2>, sequence<2, 1, 2>>,
                                           tuple<sequence<0>, sequence<1, 0, 2>>,
                                           sequence<2, 1>,
                                           sequence<3, 1>>{});
        }
        else
        {
            constexpr index_t K1   = (K2 * N0) / get_warp_size();
            constexpr index_t K2_m = K2 / K1;
            constexpr index_t K0   = BlockSize / get_warp_size() / K1;
            static_assert(KPerBlock == K0 * K1 * K2_m * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                           tuple<sequence<2, 2>, sequence<1, 2>>,
                                           tuple<sequence<0, 1>, sequence<0, 2>>,
                                           sequence<2, 1>,
                                           sequence<3, 1>>{});
        }
    }
    else
    {
        constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
        constexpr index_t K0 = KPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
        // coalesce reading for each blocks
        if constexpr(get_warp_size() % (N2 * K0) == 0)
        {
            constexpr index_t N1 = BlockSize / get_warp_size();
            static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
            static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
            constexpr index_t N0 = NPerBlock / (N2 * N1);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<2, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 1>>{});
        }
        // coalesce reading for each warps
        else
        {
            constexpr index_t N0 = BlockSize / get_warp_size();
            constexpr index_t N1 = NPerBlock / (N2 * N0);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<0>, sequence<2, 0>>,
                                           sequence<1, 2>,
                                           sequence<1, 1>>{});
        }
    }
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
    using ALayout   = remove_cvref_t<typename Problem::ALayout>;
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);

    constexpr index_t BlockSize = Problem::kBlockSize;
    constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
    constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

    constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
    constexpr index_t M0 = MPerBlock / M1;
    constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
    static_assert(total_pixels % M1 == 0);
    constexpr index_t K3     = total_pixels / M1;
    constexpr index_t kKPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
    static_assert(kKPack % K3 == 0);
    constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
    constexpr index_t warp_size = get_warp_size();
    if constexpr(warp_size % (K2 * M0) == 0)
    {
        constexpr index_t K1 = warp_size / (K2 * M0);
        constexpr index_t K0 = BlockSize / warp_size;
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
                                       tuple<sequence<2>, sequence<2, 1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0, 2>>,
                                       sequence<1, 2>,
                                       sequence<1, 3>>{});
    }
    else
    {
        constexpr index_t K1   = (K2 * M0) / get_warp_size();
        constexpr index_t K2_m = K2 / K1;
        constexpr index_t K0   = BlockSize / get_warp_size() / K1;
        static_assert(KPerBlock == K0 * K1 * K2_m * K3);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
                                       tuple<sequence<2, 2>, sequence<1, 2>>,
                                       tuple<sequence<0, 1>, sequence<0, 2>>,
                                       sequence<1, 2>,
                                       sequence<1, 3>>{});
    }
}
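Illustrative note (not part of this commit): the shuffled register descriptor splits K as KPerBlock = K0 * K1 * K2 * K3, where K3 comes from the per-thread pixel count and K2 from the LDS K-pack. A minimal sketch under an assumed configuration (BlockSize = 256, MPerBlock = 128, KPerBlock = 32, M1 = 8 fp16 elements per vector load, kKPack = 8, warp size 64; all assumed for illustration):

// Sketch: the K-split used by the shuffled register block descriptor, for one assumed config.
#include <cassert>

int main()
{
    const int BlockSize = 256, MPerBlock = 128, KPerBlock = 32;
    const int M1 = 8, kKPack = 8, warp_size = 64;

    const int M0           = MPerBlock / M1;                 // 16
    const int total_pixels = MPerBlock * KPerBlock / BlockSize; // 16 elements per thread
    const int K3           = total_pixels / M1;              // 2
    const int K2           = kKPack / K3;                    // 4
    assert(warp_size % (K2 * M0) == 0);                      // take the first branch
    const int K1 = warp_size / (K2 * M0);                    // 1
    const int K0 = BlockSize / warp_size;                    // 4

    assert(KPerBlock == K0 * K1 * K2 * K3);
    return 0;
}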
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
    using BLayout   = remove_cvref_t<typename Problem::BLayout>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);

    constexpr index_t BlockSize = Problem::kBlockSize;
    constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
    constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

    constexpr index_t K1 = WarpGemm::kK;
    constexpr index_t K0 = KPerBlock / K1;
    constexpr index_t N2 = get_warp_size() / K0;
    constexpr index_t N1 = BlockSize / get_warp_size();
    static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
    static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
    constexpr index_t N0 = NPerBlock / (N2 * N1);

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<1>, sequence<2, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 1>>{});

    constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
    constexpr index_t N0 = NPerBlock / N1;
    constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
    static_assert(total_pixels % N1 == 0);
    constexpr index_t K3     = total_pixels / N1;
    constexpr index_t kKPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
    static_assert(kKPack % K3 == 0);
    constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
    constexpr index_t warp_size = get_warp_size();
    if constexpr(warp_size % (K2 * N0) == 0)
    {
        constexpr index_t K1 = warp_size / (K2 * N0);
        constexpr index_t K0 = BlockSize / warp_size;
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                       tuple<sequence<2>, sequence<2, 1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0, 2>>,
                                       sequence<1, 2>,
                                       sequence<1, 3>>{});
    }
    else
    {
        constexpr index_t K1   = (K2 * N0) / get_warp_size();
        constexpr index_t K2_m = K2 / K1;
        constexpr index_t K0   = BlockSize / get_warp_size() / K1;
        static_assert(KPerBlock == K0 * K1 * K2_m * K3);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                       tuple<sequence<2, 2>, sequence<1, 2>>,
                                       tuple<sequence<0, 1>, sequence<0, 2>>,
                                       sequence<1, 2>,
                                       sequence<1, 3>>{});
    }
}
template <typename Problem>
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @ b74918bc
...
...
@@ -3,19 +3,23 @@
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {

template <bool kPadA_,
          bool kPadB_,
          bool kPadC_,
template <bool kPadM_,
          bool kPadN_,
          bool kPadK_,
          typename ALayout_,
          typename BLayout_,
          typename CLayout_>
struct TileGemmTraits
{
    static constexpr bool kPadA = kPadA_;
    static constexpr bool kPadB = kPadB_;
    static constexpr bool kPadC = kPadC_;
    static constexpr bool kPadM = kPadM_;
    static constexpr bool kPadN = kPadN_;
    static constexpr bool kPadK = kPadK_;

    static constexpr int _VectorSize = 16;

    using ALayout = ALayout_;
    using BLayout = BLayout_;
...
...
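Illustrative note (not part of this commit): the traits parameters are renamed from per-tensor padding (kPadA/B/C) to per-dimension padding (kPadM/N/K). A minimal instantiation sketch, assuming the layout tags from ck_tile::tensor_layout::gemm used elsewhere in this diff:

// Sketch: instantiating the renamed traits (illustrative choice of padding flags and layouts).
using Traits = ck_tile::TileGemmTraits<true,                                        // kPadM_
                                       true,                                        // kPadN_
                                       false,                                       // kPadK_
                                       ck_tile::tensor_layout::gemm::RowMajor,      // ALayout_
                                       ck_tile::tensor_layout::gemm::ColumnMajor,   // BLayout_
                                       ck_tile::tensor_layout::gemm::RowMajor>;     // CLayout_
static_assert(Traits::kPadM && !Traits::kPadK);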
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @ b74918bc
...
...
@@ -10,114 +10,150 @@
namespace ck_tile {

// fp16
using WarpGemmMfmaF16F16F32M32N32K8 =
    WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M16N16K16 =
    WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16 =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K32 =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
    WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
    WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 1>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
    WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
    WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M16N16K16,
        WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>, 4>>;
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>, 4>>;

// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 =
    WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 =
    WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16 =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
    WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
    WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
    WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
    WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16,
        WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>, 4>>;
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>, 4>>;

// fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 =
    WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 =
    WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 =
    WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 =
    WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl<
    WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl<
    WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<
    WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<
    WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>,
        WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
        2,
        swizzle_factor>>;
...
...
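Illustrative note (not part of this commit): each alias composes an MFMA implementation struct with an iteration or transpose wrapper, and the resulting WarpGemmImpl reports the combined warp-tile shape. A minimal compile-time check, assuming the headers above are included and that WarpGemmImpl exposes kM/kN/kK:

// Sketch: the K16 fp16 variant is two K8 MFMA steps over a 32x32 warp tile.
using WG = ck_tile::WarpGemmMfmaF16F16F32M32N32K16;
static_assert(WG::kM == 32 && WG::kN == 32 && WG::kK == 16);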
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @ b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -21,9 +21,15 @@ struct WarpGemmAtrributeMfma
    using BVecType = typename Impl::BVecType;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK;
    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK;
    static constexpr index_t kKPerThread = Impl::kABKPerLane;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");

    using AWarpDstrEncoding = tile_distribution_encoding<sequence<>,
...
...
@@ -51,10 +57,13 @@ struct WarpGemmAtrributeMfma
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        Impl{}(c_vec, a_vec, b_vec);
        Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
    }

    // c_vec = a_vec * b_vec
...
...
@@ -81,38 +90,144 @@ struct WarpGemmAtrributeMfmaIterateK
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;

    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
        tuple<sequence<2, 1>>,
        tuple<sequence<0, 0>>,
        sequence<2>,
        sequence<1>>;
    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

    using BWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
        tuple<sequence<2, 1>>,
        tuple<sequence<0, 0>>,
        sequence<2>,
        sequence<1>>;
    static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
                  "Multi-block on both M & N directions is not supported");

    using CWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>, sequence<Impl::kCNLane>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 0>>,
        sequence<1, 1>,
        sequence<0, 2>>;
    CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
    {
        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<0, 0>>,
                sequence<2>,
                sequence<1>>{};
        }
        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
        {
            // each M blocks share the same data
            return tile_distribution_encoding<
                sequence<Impl::kBNBlock>,
                tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<0, 2, 1>>,
                tuple<sequence<0, 0, 0>>,
                sequence<2>,
                sequence<1>>{};
        }
        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
        {
            // single block to multi-block thread mapping
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<1, 2, 1>>,
                tuple<sequence<0, 0, 1>>,
                sequence<2>,
                sequence<1>>{};
        }
    }

    CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
    {
        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<0, 0>>,
                sequence<2>,
                sequence<1>>{};
        }
        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
        {
            // single block to multi-block thread mapping
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<1, 2, 1>>,
                tuple<sequence<0, 0, 1>>,
                sequence<2>,
                sequence<1>>{};
        }
        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
        {
            // each N blocks share the same data
            return tile_distribution_encoding<
                sequence<Impl::kAMBlock>,
                tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<0, 2, 1>>,
                tuple<sequence<0, 0, 0>>,
                sequence<2>,
                sequence<1>>{};
        }
    }

    CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
    {
        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
                      sequence<Impl::kCNLane>>,
                tuple<sequence<1, 2>>,
                tuple<sequence<1, 0>>,
                sequence<1, 1>,
                sequence<0, 2>>{};
        }
        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
                      sequence<Impl::kBNBlock * Impl::kCNLane>>,
                tuple<sequence<1, 2>>,
                tuple<sequence<1, 0>>,
                sequence<1, 1>,
                sequence<0, 2>>{};
        }
        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>,
                      sequence<Impl::kCNLane>>,
                tuple<sequence<1, 2>>,
                tuple<sequence<1, 0>>,
                sequence<1, 1>,
                sequence<0, 2>>{};
        }
    }

    using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
    using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
    using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());

    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
...
...
@@ -122,10 +237,33 @@ struct WarpGemmAtrributeMfmaIterateK
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter]);
                       .template get_as<typename Impl::BVecType>()[iKIter],
                   bool_constant<post_nop_>{});
        });
    }

    template <index_t iKIter, bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   number<iKIter>,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
        static_assert(iKIter < kKIter);
        // static_for<0, kKIter, 1>{}([&](auto iKIter) {
        Impl{}(c_vec,
               reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[iKIter],
               reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[iKIter],
               bool_constant<post_nop_>{});
        //});
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
...
...
@@ -164,9 +302,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
    using BVecType = typename Impl::AVecType;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK;
    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK;
    static constexpr index_t kKPerThread = Impl::kABKPerLane;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");

    using AWarpDstrEncoding = tile_distribution_encoding<sequence<>,
...
...
@@ -194,11 +338,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        // swap A and B
        Impl{}(c_vec, b_vec, a_vec);
        Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
    }

    // c_vec = a_vec * b_vec
...
...
@@ -222,9 +369,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
    using BVecType = typename Impl::AVecType;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK;
    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK;
    static constexpr index_t kKPerThread = Impl::kABKPerLane;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");

    using AWarpDstrEncoding = tile_distribution_encoding<sequence<>,
...
...
@@ -255,12 +408,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
        sequence<2, 2>,
        sequence<0, 2>>;

    template <bool post_nop_ = false>
    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        // swap A and B
        Impl{}(c_vec, b_vec, a_vec);
        Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
    }

    // c_vec = a_vec * b_vec
...
...
@@ -287,38 +443,144 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;

    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
        tuple<sequence<2, 1>>,
        tuple<sequence<0, 0>>,
        sequence<2>,
        sequence<1>>;
    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

    using BWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
        tuple<sequence<2, 1>>,
        tuple<sequence<0, 0>>,
        sequence<2>,
        sequence<1>>;
    static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
                  "Multi-block on both M & N directions is not supported");

    using CWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Impl::kCNLane>, sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
        tuple<sequence<2, 1>>,
        tuple<sequence<1, 0>>,
        sequence<2, 2>,
        sequence<0, 2>>;
    CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
    {
        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<0, 0>>,
                sequence<2>,
                sequence<1>>{};
        }
        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
        {
            // single block to multi-block thread mapping
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<1, 2, 1>>,
                tuple<sequence<0, 0, 1>>,
                sequence<2>,
                sequence<1>>{};
        }
        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
        {
            // each N blocks share the same data
            return tile_distribution_encoding<
                sequence<Impl::kAMBlock>,
                tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<0, 2, 1>>,
                tuple<sequence<0, 0, 0>>,
                sequence<2>,
                sequence<1>>{};
        }
    }

    CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
    {
        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<0, 0>>,
                sequence<2>,
                sequence<1>>{};
        }
        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
        {
            // each M blocks share the same data
            return tile_distribution_encoding<
                sequence<Impl::kBNBlock>,
                tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<0, 2, 1>>,
                tuple<sequence<0, 0, 0>>,
                sequence<2>,
                sequence<1>>{};
        }
        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
        {
            // single block to multi-block thread mapping
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
                      sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
                tuple<sequence<1, 2, 1>>,
                tuple<sequence<0, 0, 1>>,
                sequence<2>,
                sequence<1>>{};
        }
    }

    CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
    {
        if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kCNLane>,
                      sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<1, 0>>,
                sequence<2, 2>,
                sequence<0, 2>>{};
        }
        else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kBNBlock * Impl::kCNLane>,
                      sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<1, 0>>,
                sequence<2, 2>,
                sequence<0, 2>>{};
        }
        else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
        {
            return tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Impl::kCNLane>,
                      sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>>,
                tuple<sequence<2, 1>>,
                tuple<sequence<1, 0>>,
                sequence<2, 2>,
                sequence<0, 2>>{};
        }
    }

    using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
    using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
    using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());

    template <bool post_nop_ = false>
    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
...
...
@@ -328,10 +590,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter],
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter]);
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   bool_constant<post_nop_>{});
        });
    }

    template <index_t iKIter, bool post_nop_ = false>
    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   number<iKIter>,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
        static_assert(iKIter < kKIter);
        // swap A and B, value and type
        // static_for<0, kKIter, 1>{}([&](auto iKIter) {
        Impl{}(c_vec,
               reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[iKIter],
               reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[iKIter],
               bool_constant<post_nop_>{});
        //});
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
...
...
@@ -372,10 +658,16 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t SFactor = SFactor_; // group how many CM1 together
    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
    static constexpr index_t SFactor = SFactor_; // group how many CM1 together

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");

    using AWarpDstrEncoding = tile_distribution_encoding<sequence<>,
...
...
@@ -429,8 +721,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
        sequence<0, 2>>;
#endif

    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
...
...
@@ -440,10 +735,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter],
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter]);
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   bool_constant<post_nop_>{});
        });
    }

    template <index_t iKIter, bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   number<iKIter>,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
        static_assert(iKIter < kKIter);
        // swap A and B, value and type
        // static_for<0, kKIter, 1>{}([&](auto iKIter) {
        Impl{}(c_vec,
               reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[iKIter],
               reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[iKIter],
               bool_constant<post_nop_>{});
        //});
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
...
...
@@ -483,10 +801,16 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t SFactor = SFactor_; // group how many CM1 together
    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
    static constexpr index_t SFactor = SFactor_; // group how many CM1 together

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

    static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
                  "Multi-block WarpGemmAttributeMfmaImpl is not supported");

    using AWarpDstrEncoding = tile_distribution_encoding<sequence<>,
...
...
@@ -518,8 +842,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
...
...
@@ -529,10 +856,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter]);
                       .template get_as<typename Impl::BVecType>()[iKIter],
                   bool_constant<post_nop_>{});
        });
    }

    template <index_t iKIter, bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   number<iKIter>,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
        static_assert(iKIter < kKIter);
        // static_for<0, kKIter, 1>{}([&](auto iKIter) {
        Impl{}(c_vec,
               reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[iKIter],
               reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[iKIter],
               bool_constant<post_nop_>{});
        //});
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
...
...
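Illustrative note (not part of this commit): every accumulate operator() in this header gains a bool_constant<post_nop_> tag so a caller can request the s_nop-terminated MFMA variant when the accumulator is consumed right after the instruction. A minimal call-site sketch, assuming an attribute type like WarpGemmAtrributeMfma<...> with the member vector types shown above:

// Sketch: the two ways a pipeline might invoke the accumulate operator (types assumed).
template <typename WarpGemmAttr>
CK_TILE_DEVICE void accumulate_once(typename WarpGemmAttr::CVecType& c,
                                    const typename WarpGemmAttr::AVecType& a,
                                    const typename WarpGemmAttr::BVecType& b)
{
    WarpGemmAttr wg{};
    wg(c, a, b);                                   // default: no trailing s_nop
    wg(c, a, b, ck_tile::bool_constant<true>{});   // post_nop_ variant, s_nop appended after the MFMA
}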
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @ b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -7,12 +7,68 @@
namespace ck_tile {

// TODO: refactor warp-gemm
// currently there is a discrepancy for vav/vva if we need transpose C/D
// e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrEnum
// because we swap the A/B pointer in the _impl code (that info is not known here)
enum class WGAttrCtlEnum
{
    Default_ = 0,
    Raw_vvv  = 1, // c-vgpr, a-vgpr, b-vgpr
    Raw_vaa  = 2, // c-vgpr, a-agpr, b-agpr
    Raw_vav  = 3, // c-vgpr, a-agpr, b-vgpr
    Raw_vva  = 4, // c-vgpr, a-vgpr, b-agpr
    Raw_avv  = 5, // c-agpr, a-vgpr, b-vgpr
    // raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
"s_nop 3" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
} \
else \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
}
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
{ \
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
}
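// (Illustrative note, not part of this commit.) For Ctrl == WGAttrCtlEnum::Raw_vvv and
// post_nop_ == true, DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl) expands to roughly:
//
//   asm volatile("v_mfma_f32_32x32x8f16 %0, %1, %2, %3 ; yyy\n"
//                "s_nop 3"
//                : "+v"(c_vec)
//                : "v"(a_vec), "v"(b_vec), "v"(c_vec)
//                :);
//
// i.e. the constraint letters select the vector/accumulator register file for each operand
// (v or a), and the trailing s_nop is only emitted when post_nop_ is requested.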
// FP16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<fp16_t, 4>;
    using BVecType = ext_vector_t<fp16_t, 4>;
...
...
@@ -22,6 +78,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
    static constexpr index_t kN = 32;
    static constexpr index_t kK = 8;

    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;

    static constexpr index_t kAMLane     = 32;
    static constexpr index_t kBNLane     = 32;
    static constexpr index_t kABKLane    = 2;
...
...
@@ -33,16 +92,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ignore = c_vec;
            ignore = a_vec;
            ignore = b_vec;
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
...
...
@@ -52,18 +118,20 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
#else
        ignore = a_vec;
        ignore = b_vec;
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<fp16_t, 4>;
    using BVecType = ext_vector_t<fp16_t, 4>;
...
...
@@ -73,6 +141,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
    static constexpr index_t kN = 16;
    static constexpr index_t kK = 16;

    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;

    static constexpr index_t kAMLane     = 16;
    static constexpr index_t kBNLane     = 16;
    static constexpr index_t kABKLane    = 4;
...
...
@@ -84,16 +155,151 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
#if defined(__gfx9__)
        c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<fp16_t, 4>;
    using BVecType = ext_vector_t<fp16_t, 4>;
    using CVecType = ext_vector_t<float, 4>;

    static constexpr index_t kM = 4;
    static constexpr index_t kN = 64;
    static constexpr index_t kK = 4;

    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 16;
    // we only write down single block (4 threads) thread mapping here
    static constexpr index_t kAMLane     = 4;
    static constexpr index_t kBNLane     = 4;
    static constexpr index_t kABKLane    = 1;
    static constexpr index_t kABKPerLane = 4;

    static constexpr index_t kCMLane     = 1;
    static constexpr index_t kCNLane     = 4;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ignore = c_vec;
            ignore = a_vec;
            ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
#if defined(__gfx9__)
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
        ignore = c_vec;
        ignore = a_vec;
        ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<fp16_t, 4>;
    using BVecType = ext_vector_t<fp16_t, 4>;
    using CVecType = ext_vector_t<float, 4>;

    static constexpr index_t kM = 64;
    static constexpr index_t kN = 4;
    static constexpr index_t kK = 4;

    static constexpr index_t kAMBlock = 16;
    static constexpr index_t kBNBlock = 1;
    // we only write down single block (4 threads) thread mapping here
    static constexpr index_t kAMLane     = 4;
    static constexpr index_t kBNLane     = 4;
    static constexpr index_t kABKLane    = 1;
    static constexpr index_t kABKPerLane = 4;

    static constexpr index_t kCMLane     = 1;
    static constexpr index_t kCNLane     = 4;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ignore = c_vec;
            ignore = a_vec;
            ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
...
...
@@ -101,7 +307,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
    {
#if defined(__gfx9__)
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
            __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
        ignore = a_vec;
        ignore = b_vec;
...
...
@@ -111,11 +317,13 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
};

// Bf16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
{
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<bf16_t, 4>;
    using BVecType = ext_vector_t<bf16_t, 4>;
...
...
@@ -125,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
    static constexpr index_t kN = 32;
    static constexpr index_t kK = 8;

    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;

    static constexpr index_t kAMLane     = 32;
    static constexpr index_t kBNLane     = 32;
    static constexpr index_t kABKLane    = 2;
...
...
@@ -136,28 +347,35 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
        else
        {
#if defined(__gfx90a__) || defined(__gfx94__)
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec, 0, 0, 0);
            });
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec, 0, 0, 0);
            });
#else
            ignore = c_vec;
            ignore = a_vec;
            ignore = b_vec;
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
...
...
@@ -181,18 +399,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
        });
        return c_vec;
#else
        ignore = a_vec;
        ignore = b_vec;
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
{
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<bf16_t, 4>;
    using BVecType = ext_vector_t<bf16_t, 4>;
...
...
@@ -202,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
    static constexpr index_t kN = 16;
    static constexpr index_t kK = 16;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane  = 16;
    static constexpr index_t kBNLane  = 16;
    static constexpr index_t kABKLane = 4;
...
...
@@ -213,28 +436,34 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
-    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
+    template <bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant<post_nop_> = {}) const
    {
+        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
+        {
#if defined(__gfx90a__) || defined(__gfx94__)
            c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec,
                    0,
                    0,
                    0);
            });
#else
-            ignore = c_vec;
-            ignore = a_vec;
-            ignore = b_vec;
+            ck_tile::ignore = c_vec;
+            ck_tile::ignore = a_vec;
+            ck_tile::ignore = b_vec;
#endif
+        }
    }
// c_vec = a_vec * b_vec
...
...
@@ -257,6 +486,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
                    0);
        });
        return c_vec;
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<bf16_t, 4>;
    using BVecType = ext_vector_t<bf16_t, 4>;
    using CVecType = ext_vector_t<float, 4>;

    static constexpr index_t kM = 4;
    static constexpr index_t kN = 64;
    static constexpr index_t kK = 4;

    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 16;

    // we only write down single block (4 threads) thread mapping here
    static constexpr index_t kAMLane     = 4;
    static constexpr index_t kBNLane     = 4;
    static constexpr index_t kABKLane    = 1;
    static constexpr index_t kABKPerLane = 4;

    static constexpr index_t kCMLane     = 1;
    static constexpr index_t kCNLane     = 4;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ignore = c_vec;
            ignore = a_vec;
            ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
#if defined(__gfx9__)
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
        ignore = a_vec;
        ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<bf16_t, 4>;
    using BVecType = ext_vector_t<bf16_t, 4>;
    using CVecType = ext_vector_t<float, 4>;

    static constexpr index_t kM = 64;
    static constexpr index_t kN = 4;
    static constexpr index_t kK = 4;

    static constexpr index_t kAMBlock = 16;
    static constexpr index_t kBNBlock = 1;

    // we only write down single block (4 threads) thread mapping here
    static constexpr index_t kAMLane     = 4;
    static constexpr index_t kBNLane     = 4;
    static constexpr index_t kABKLane    = 1;
    static constexpr index_t kABKPerLane = 4;

    static constexpr index_t kCMLane     = 1;
    static constexpr index_t kCNLane     = 4;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ignore = c_vec;
            ignore = a_vec;
            ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
#if defined(__gfx9__)
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
        ignore = a_vec;
        ignore = b_vec;
...
...
@@ -266,12 +623,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
};
// FP8
-template <typename AType_, typename BType_>
+template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
-    using ADataType = AType_;
-    using BDataType = BType_;
-    using CDataType = float;
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+    using ADataType = AType_;
+    using BDataType = BType_;
+    using CDataType = float;

    using AVecType = ext_vector_t<ADataType, 8>;
    using BVecType = ext_vector_t<BDataType, 8>;
...
...
@@ -281,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
    static constexpr index_t kN = 32;
    static constexpr index_t kK = 16;
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane  = 32;
    static constexpr index_t kBNLane  = 32;
    static constexpr index_t kABKLane = 2;
...
...
@@ -292,38 +653,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
-    CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
+    template <bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant<post_nop_> = {}) const
    {
+        if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
+        {
+            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v") }
+            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v") }
+            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v") }
+            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v") }
+        }
+        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
+        {
+            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v") }
+            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v") }
+            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v") }
+            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v") }
+        }
+        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
+        {
+            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v") }
+            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v") }
+            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v") }
+            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v") }
+        }
+        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
+        {
+            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v") }
+            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v") }
+            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v") }
+            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>) { DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v") }
+        }
+        else
+        {
#if defined(__gfx94__)
            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
            static_for<0, 8, 1>{}([&](auto k) {
                float a_f32 = type_convert<float>(
                    reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
                        .template get_as<ADataType>()[number<k>{}]);
                float b_f32 = type_convert<float>(
                    reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
                        .template get_as<BDataType>()[number<k>{}]);
                c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
            });
#else
-            ignore = c_vec;
-            ignore = a_vec;
-            ignore = b_vec;
+            ck_tile::ignore = c_vec;
+            ck_tile::ignore = a_vec;
+            ck_tile::ignore = b_vec;
#endif
+        }
    }
// c_vec = a_vec * b_vec
...
...
@@ -356,20 +799,100 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
        });
        return c_vec;
#else
-        ignore = a_vec;
-        ignore = b_vec;
+        ck_tile::ignore = a_vec;
+        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};

+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 =
-    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>;
+    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>;

+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 =
-    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t>;
+    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>;

+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
-    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t>;
+    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t, Ctrl_>;

+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 =
-    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t>;
+    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>;
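With the base implementation now parameterized on the `Ctrl_` register-constraint mode, the four aliases simply forward it for each fp8/bf8 operand combination. A minimal usage sketch follows; the include path and the assumption that these templates live in namespace `ck_tile` are inferred from this commit rather than stated by it.

```cpp
// Minimal sketch (assumptions: header path below; templates live in namespace ck_tile;
// WGAttrCtlEnum::Default_ selects the builtin/emulated path rather than the raw-ASM paths).
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"

using Fp8Bf8WarpMfma =
    ck_tile::WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<ck_tile::WGAttrCtlEnum::Default_>;

static_assert(Fp8Bf8WarpMfma::kM == 32 && Fp8Bf8WarpMfma::kN == 32 && Fp8Bf8WarpMfma::kK == 16,
              "one wavefront accumulates a 32x32 C tile over K = 16 per MFMA step");
```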
// int8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = int8_t;
    using BDataType = int8_t;
    using CDataType = int32_t;

    using AVecType = ext_vector_t<ADataType, 8>;
    using BVecType = ext_vector_t<BDataType, 8>;
    using CVecType = ext_vector_t<CDataType, 16>;

    static constexpr index_t kM = 32;
    static constexpr index_t kN = 32;
    static constexpr index_t kK = 16;

    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;

    static constexpr index_t kAMLane     = 32;
    static constexpr index_t kBNLane     = 32;
    static constexpr index_t kABKLane    = 2;
    static constexpr index_t kABKPerLane = 8;

    static constexpr index_t kCMLane     = 2;
    static constexpr index_t kCNLane     = 32;
    static constexpr index_t kCM0PerLane = 4;
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
        else
        {
#if defined(__gfx94__)
            c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
                bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
            static_for<0, 8, 1>{}([&](auto k) {
                float a_f32 = type_convert<float>(
                    reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
                        .template get_as<ADataType>()[number<k>{}]);
                float b_f32 = type_convert<float>(
                    reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
                        .template get_as<BDataType>()[number<k>{}]);
                c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
            });
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
        CVecType c_vec{0};
        operator()(c_vec, a_vec, b_vec);
        return c_vec;
    }
};

#undef DISPATCH_MFMA_
} // namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @ b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -21,40 +21,44 @@ struct WarpGemmMfmaDispatcher;
// clang-format off
// fp16
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32,  8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32,  8, true>  { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, true>  { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, true>  { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32,  8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32,  8, true>  { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true>  { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true>  { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float,  4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64,  4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32,  8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
-template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32,  8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// bf16
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32,  8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32,  8, true>  { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, true>  { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, true>  { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32,  8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32,  8, true>  { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true>  { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true>  { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float,  4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64,  4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32,  8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
-template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32,  8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8
-template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
-template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
-template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
-template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
-template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
-template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
-template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
-template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
+template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true>  { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
// clang-format on
} // namespace impl
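A pipeline policy picks a concrete warp GEMM type through this table at compile time. The snippet below is a minimal sketch only; it assumes the dispatcher lives in `ck_tile::impl` (as the namespace comment suggests), that the 7-argument form maps to `<AType, BType, CAccType, M, N, K, TransposeC>` with the swizzle flag defaulted, and that the resolved types sit in namespace `ck_tile`.

```cpp
// Minimal sketch of resolving a warp GEMM from the dispatcher table (assumptions noted above).
#include <type_traits>
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"

using WarpGemm = typename ck_tile::impl::
    WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false>::Type;

static_assert(std::is_same_v<WarpGemm, ck_tile::WarpGemmMfmaF16F16F32M32N32K16>,
              "fp16 32x32x16 without C transposition resolves to the plain MFMA warp GEMM");
```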
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
View file @ b74918bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -14,6 +14,11 @@ struct WarpGemmImpl
    static constexpr index_t kM = WarpGemmAttribute::kM;
    static constexpr index_t kN = WarpGemmAttribute::kN;
    static constexpr index_t kK = WarpGemmAttribute::kK;
+    /// @brief The number of elements in K dimension processed by single thread in wavefront.
+    ///
+    /// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
+    /// In such situation this value reflects this fact.
+    static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;

    using ADataType = typename WarpGemmAttribute::ADataType;
    using BDataType = typename WarpGemmAttribute::BDataType;
...
...
@@ -31,11 +36,21 @@ struct WarpGemmImpl
    using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
    using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;

-    CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const
+    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
    {
-        using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
-        using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
-        using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
+        return WarpGemmAttribute_::get_num_of_access();
    }

+    template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
+    CK_TILE_DEVICE void
+    operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
+    {
+        static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
+                      detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
+                      detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
+
+        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
+        using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
+        using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;

        constexpr auto I0 = number<0>{};
...
...
@@ -44,18 +59,49 @@ struct WarpGemmImpl
        auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];

        // c_vec += a_vec * b_vec
-        WarpGemmAttribute{}(c_vec, a_vec, b_vec);
+        WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});

        c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
    }

-    CK_TILE_DEVICE auto operator()(const AWarpTensor& a, const BWarpTensor& b) const
+    template <typename CTensor,
+              typename ATensor,
+              typename BTensor,
+              index_t i_subk,
+              bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CTensor& c,
+                                   const ATensor& a,
+                                   const BTensor& b,
+                                   number<i_subk>,
+                                   bool_constant<post_nop_> = {}) const
    {
-        CWarpTensor c;
-
-        using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
-        using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
-        using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
+        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
+        using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
+        using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;

        constexpr auto I0 = number<0>{};

        const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
        const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
+        auto c_vec       = c.get_thread_buffer().template get_as<CVec>()[I0];

        // c_vec += a_vec * b_vec
+        WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});
+        c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
    }

+    template <typename ATensor, typename BTensor>
+    CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
    {
+        using CTensor = CWarpTensor;
+        static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
+                      detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
        CTensor c;

+        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
+        using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
+        using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;

        constexpr auto I0 = number<0>{};
...
...
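After this change `WarpGemmImpl` exposes an accumulating call, a per-sub-K call that forwards an `i_subk` index and a `post_nop_` tag to the attribute, and a value-returning call that builds a fresh accumulator. The sketch below only illustrates those call forms; `WG` stands for a concrete `WarpGemmImpl` instantiation and the tile arguments are assumed to be distributed tensors matching its `AWarpTensor` / `BWarpTensor` / `CWarpTensor` layouts. The meaning of the `post_nop_` flag is defined by the attribute (presumably a scheduling hint), which is why it is passed as a `bool_constant` tag rather than a runtime value.

```cpp
// Sketch of the three WarpGemmImpl call forms (illustrative only; assumptions noted above).
template <typename WG, typename CTile, typename ATile, typename BTile>
CK_TILE_DEVICE void warp_gemm_call_forms(CTile& c, const ATile& a, const BTile& b)
{
    WG wg{};

    // accumulate: c += a * b
    wg(c, a, b);

    // accumulate one sub-K step, forwarding the post-MFMA "nop" tag to the attribute
    wg(c, a, b, ck_tile::number<0>{}, ck_tile::bool_constant<true>{});

    // fresh accumulator: returns c = a * b
    auto c_out = wg(a, b);
    ck_tile::ignore = c_out;
}
```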
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @ b74918bc
...
...
@@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass
        auto [mean, var] = block_welford(acc, cur_count, max_count);
        block_welford_sync(mean, var, cur_count);
        block_welford_cross_warp_sync(mean, var, cur_count, smem);
-        block_tile_welford_post_scale_var(var, cur_count);
+        block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});

        // compute inv-std
        auto inv_std = tile_elementwise_in(
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @ b74918bc
...
...
@@ -35,6 +35,7 @@ struct Layernorm2dFwdPipelineTwoPass
    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
    static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
    static constexpr bool kPadN = Problem::Traits::kPadN;
+    static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
    static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
    static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...
...
@@ -137,15 +138,22 @@ struct Layernorm2dFwdPipelineTwoPass
        block_welford_sync(mean, var, cur_count);
        block_welford_cross_warp_sync(mean, var, cur_count, smem);
-        block_tile_welford_post_scale_var(var, cur_count);
+        block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});

        // compute inv-std
        auto inv_std = tile_elementwise_in(
            [&](const auto& v_) {
-                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
+                if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
+                {
+                    return type_convert<ComputeDataType>(1.0f) *
+                           __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
+                }
+                else
+                {
+                    return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
+                }
            },
            var);

        if constexpr(kSaveMean)
            store_tile(mean_window, cast_tile<MeanDataType>(mean));

        if constexpr(kSaveInvStd)
...
...
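The `kFastFDiv` path above trades a full-precision IEEE division for a single hardware reciprocal estimate when the compute type is `float`. The snippet below is only a host-side illustration of that trade-off; `approx_rcp` is a stand-in for `__builtin_amdgcn_rcpf`, which is a device-only builtin.

```cpp
// Host-side illustration of the kFastFDiv trade-off (approx_rcp is a placeholder for the
// device-only __builtin_amdgcn_rcpf hardware reciprocal estimate).
#include <cmath>

inline float approx_rcp(float x) { return 1.0f / x; } // stand-in for the HW rcp instruction

inline float inv_std_exact(float var, float eps) { return 1.0f / std::sqrt(var + eps); }
inline float inv_std_fast(float var, float eps) { return 1.0f * approx_rcp(std::sqrt(var + eps)); }

// The fast path replaces the division by one reciprocal instruction at slightly reduced
// accuracy, which is why it is gated behind Problem::Traits::kFastFDiv.
```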