Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bea1a130
Commit
bea1a130
authored
Oct 10, 2024
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into ck_tile/fav3_fwd_sept
parents
977265fd
6f27bc98
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
447 additions
and
142 deletions
+447
-142
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+40
-15
include/ck_tile/core/container/thread_buffer.hpp
include/ck_tile/core/container/thread_buffer.hpp
+1
-1
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+37
-10
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-0
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+171
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+82
-51
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+50
-31
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+6
-5
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+6
-9
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+7
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+3
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
+4
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+10
-7
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+27
-0
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
bea1a130
...
@@ -41,18 +41,39 @@ template <typename LayoutA,
...
@@ -41,18 +41,39 @@ template <typename LayoutA,
float
gemm_calc
(
const
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kTilePermute
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
LayoutC
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
,
LayoutA
,
LayoutB
,
LayoutC
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_b
,
...
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
...
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
CodegenPipelineProblem
=
ck_tile
::
BlockGemmPipelineProblem
<
ADataType
,
using
CodegenGemmTraits
=
ck_tile
::
BDataType
,
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
AccDataType
,
CodegenGemmShape
,
using
CodegenPipelineProblem
=
ck_tile
::
kPadA
,
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
kPadB
,
kPadC
>
;
using
CodegenGemmPipeline
=
ck_tile
::
Block
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
invoke_gemm
<
ck_tile
::
half_t
,
invoke_gemm
<
ck_tile
::
half_t
,
matrix_a_layout
,
matrix_a_layout
,
...
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
...
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
(
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
...
...
include/ck_tile/core/container/thread_buffer.hpp
View file @
bea1a130
...
@@ -58,7 +58,7 @@ struct thread_buffer {
...
@@ -58,7 +58,7 @@ struct thread_buffer {
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
()
const
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
()
const
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
at
(
number
<
I
>
)
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
at
(
number
<
I
>
)
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
(
number
<
I
>
)
const
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
(
number
<
I
>
)
const
{
return
get
(
I
);
}
template
<
typename
X_
,
template
<
typename
X_
,
typename
std
::
enable_if
<
has_same_scalar_type
<
value_type
,
X_
>
::
value
,
bool
>::
type
=
false
>
typename
std
::
enable_if
<
has_same_scalar_type
<
value_type
,
X_
>
::
value
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
_get_as
()
const
CK_TILE_HOST_DEVICE
constexpr
auto
_get_as
()
const
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
bea1a130
...
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const
BElementOp
&
b_element_op
=
{},
const
BElementOp
&
b_element_op
=
{},
const
ACCElementOp
&
acc_element_op
=
{})
const
ACCElementOp
&
acc_element_op
=
{})
{
{
const
int
N
=
b_n_k
.
mDesc
.
get_lengths
()[
0
];
const
int
N
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
b_n_k
.
mDesc
.
get_lengths
()[
0
]
:
b_n_k
.
mDesc
.
get_lengths
()[
1
];
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
...
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_element_op
(
a_m_k
(
m
,
k
))
?
a_element_op
(
a_m_k
(
m
,
k
))
:
a_element_op
(
a_m_k
(
k
,
m
));
:
a_element_op
(
a_m_k
(
k
,
m
));
BDataType
v_b
=
b_element_op
(
b_n_k
(
n
,
k
));
BDataType
v_b
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
b_element_op
(
b_n_k
(
n
,
k
))
:
b_element_op
(
b_n_k
(
k
,
n
));
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
c_m_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
CDataType
&
c_ref
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
c_m_n
(
m
,
n
)
:
c_m_n
(
n
,
m
);
c_ref
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
}
}
};
};
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
BDataType
*
B
,
BDataType
*
B
,
CDataType
*
C
,
CDataType
*
C
,
...
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
...
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if
(
row
<
M
&&
col
<
N
)
if
(
row
<
M
&&
col
<
N
)
{
{
AccDataType
acc
=
0.0
;
AccDataType
acc
=
0.0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
acc
+=
static_cast
<
AccDataType
>
(
A
[
row
*
strideA
+
k
])
*
// Adjust indexing based on matrix layout
static_cast
<
AccDataType
>
(
B
[
col
*
strideB
+
k
]);
int
a_index
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideA
+
k
:
k
*
strideA
+
row
;
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
col
*
strideB
+
k
:
k
*
strideB
+
col
;
acc
+=
static_cast
<
AccDataType
>
(
A
[
a_index
])
*
static_cast
<
AccDataType
>
(
B
[
b_index
]);
}
}
C
[
row
*
strideC
+
col
]
=
acc
;
// Store as AccDataType
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideC
+
col
:
col
*
strideC
+
row
;
C
[
c_index
]
=
acc
;
}
}
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
DeviceMem
&
c_device
,
...
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
...
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
errC
=
hipMemcpy
(
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
...
...
include/ck_tile/ops/epilogue.hpp
View file @
bea1a130
...
@@ -3,5 +3,6 @@
...
@@ -3,5 +3,6 @@
#pragma once
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
0 → 100644
View file @
bea1a130
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace
ck_tile
{
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kTilePermute_
,
index_t
kRank_
,
index_t
kPerm0
,
index_t
kPerm1
,
index_t
TileSize0
,
index_t
TileSize1
,
index_t
kPerm2
=
0
,
index_t
kPerm3
=
0
,
index_t
kPerm4
=
0
,
index_t
TileSize2
=
0
,
index_t
TileSize3
=
0
,
index_t
TileSize4
=
0
>
struct
CShuffleEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTilePermute
=
kTilePermute_
;
static
constexpr
index_t
kRank
=
kRank_
;
static
constexpr
index_t
kPerm
[
CK_TILE_MAX_RANK
]
=
{
kPerm0
,
kPerm1
,
kPerm2
,
kPerm3
,
kPerm4
};
static
constexpr
index_t
tile_sizes
[
CK_TILE_MAX_RANK
]
=
{
TileSize0
,
TileSize1
,
TileSize2
,
TileSize3
,
TileSize4
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
CShuffleEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
const
index_t
*
kPerm
=
Problem
::
kPerm
;
static
constexpr
bool
kTilePermute
=
Problem
::
kTilePermute
;
static
constexpr
index_t
kRank
=
Problem
::
kRank
;
const
index_t
*
tile_sizes
=
Problem
::
tile_sizes
;
// No additional shared memory needed
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
template
<
typename
OAccTile
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
{
using
DataType
=
typename
OAccTile
::
DataType
;
// Get thread buffer
auto
&
thread_buf
=
o_acc_tile
.
get_thread_buffer
();
// Create a temporary buffer to hold the permuted data
thread_buffer
<
DataType
,
OAccTile
::
kThreadElementSpaceSize
>
permuted_thread_buf
;
// Get the lengths of each dimension
auto
thread_tensor_lengths
=
o_acc_tile
.
get_lengths
();
// Total number of elements
index_t
total_elements
=
OAccTile
::
kThreadElementSpaceSize
;
// Iterate over all elements
for
(
index_t
linear_idx
=
0
;
linear_idx
<
total_elements
;
++
linear_idx
)
{
// Convert linear index to multi-dimensional indices
array
<
index_t
,
kRank
>
indices
;
index_t
remaining
=
linear_idx
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
indices
(
rev_i
)
=
remaining
%
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
remaining
/=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Apply the permutation
array
<
index_t
,
kRank
>
permuted_indices
;
static_for
<
0
,
kRank
,
1
>
{}(
[
&
](
auto
i
)
{
permuted_indices
(
i
)
=
indices
.
get
(
number
<
Problem
::
kPerm
[
i
]
>
{});
});
// Compute offsets
index_t
dst_offset
=
0
;
index_t
stride
=
1
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
dst_offset
+=
permuted_indices
[
rev_i
]
*
stride
;
stride
*=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Move the data
permuted_thread_buf
(
dst_offset
)
=
thread_buf
[
linear_idx
];
}
// Copy the permuted data back to the original thread buffer
for
(
index_t
i
=
0
;
i
<
total_elements
;
++
i
)
{
thread_buf
.
set_as
(
i
,
permuted_thread_buf
.
get
(
i
));
}
}
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t
tile_coords
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
tile_coords
[
i
]
=
current_window_origin
[
i
]
/
tile_sizes
[
i
];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t
permuted_tile_coords
[
CK_TILE_MAX_RANK
];
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_tile_coords
[
i
]
=
tile_coords
[
kPerm
[
i
]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t
permuted_window_origin
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_window_origin
[
i
]
=
permuted_tile_coords
[
i
]
*
tile_sizes
[
i
];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename
ODramWindowTmp
::
BottomTensorIndex
step
=
{};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
step
[
i
]
=
permuted_window_origin
[
i
]
-
current_window_origin
[
i
];
}
// Move the window
move_tile_window
(
o_dram_window_tmp
,
step
);
// Permute the data within the tile if necessary
if
constexpr
(
kTilePermute
)
{
permute_tile_data
(
o_acc_tile
);
}
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
bea1a130
...
@@ -5,8 +5,9 @@
...
@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...
@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
QDataType
,
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QDataType
,
typename
Problem
::
QDataType
,
...
@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
GemmDataType
,
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimV
,
Problem
::
kPadHeadDimV
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
OGradDataType
,
GemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
OGradDataType
,
...
@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
GemmDataType
,
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
GemmDataType
,
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
// these are for global load
// these are for global load
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
bea1a130
...
@@ -5,8 +5,9 @@
...
@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...
@@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
QDataType
,
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
@@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
@@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
QDataType
,
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
@@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBSmemCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmASmemBSmemCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
@@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
PDataType
,
GemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
typename
Problem
::
OaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
auto
warp_gemm
=
[
&
]()
{
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
@@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
typename
Problem
::
OaccDataType
,
typename
Problem
::
OaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
...
include/ck_tile/ops/gemm.hpp
View file @
bea1a130
...
@@ -24,12 +24,13 @@
...
@@ -24,12 +24,13 @@
#include "ck_tile/ops/gemm/block/block_gemm_utils.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_utils.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
bea1a130
...
@@ -11,20 +11,12 @@
...
@@ -11,20 +11,12 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
typename
GemmPipeline_
,
typename
EpiloguePipeline_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
GemmKernel
struct
GemmKernel
{
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
kBlockSize
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
kBlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
...
@@ -32,6 +24,10 @@ struct GemmKernel
...
@@ -32,6 +24,10 @@ struct GemmKernel
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
LayoutA
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutC
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M_size
,
index_t
N_size
,
index_t
Batch_size
)
__host__
static
constexpr
auto
GridSize
(
index_t
M_size
,
index_t
N_size
,
index_t
Batch_size
)
{
{
return
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
return
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
...
@@ -184,6 +180,7 @@ struct GemmKernel
...
@@ -184,6 +180,7 @@ struct GemmKernel
c_pad_view
,
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
{
i_m
,
i_n
});
EpiloguePipeline
{}(
CBlockWindow_pad
,
acc
);
EpiloguePipeline
{}(
CBlockWindow_pad
,
acc
);
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
bea1a130
...
@@ -4,15 +4,15 @@
...
@@ -4,15 +4,15 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
// A Tile Window: global memory
// A Tile Window: global memory
// B Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
Block
GemmPipelineAGmemBGmemCRegV1
struct
GemmPipelineAGmemBGmemCRegV1
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
@@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
...
@@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static
constexpr
bool
kPadB
=
Problem
::
kPadB
;
static
constexpr
bool
kPadB
=
Problem
::
kPadB
;
static
constexpr
bool
kPadC
=
Problem
::
kPadC
;
static
constexpr
bool
kPadC
=
Problem
::
kPadC
;
using
LayoutA
=
remove_cvref_t
<
typename
Problem
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
Problem
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
Problem
::
LayoutC
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
{
{
return
ck_tile
::
integer_divide_ceil
(
return
ck_tile
::
integer_divide_ceil
(
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
bea1a130
...
@@ -7,9 +7,9 @@
...
@@ -7,9 +7,9 @@
namespace
ck_tile
{
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV1
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
// Default policy class should not be templated, put template on member functions instead
struct
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
{
#if 0
#if 0
// 2d
// 2d
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
bea1a130
...
@@ -4,15 +4,15 @@
...
@@ -4,15 +4,15 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
// A Tile Window: global memory
// A Tile Window: global memory
// B Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
struct
Block
GemmPipelineAGmemBGmemCRegV2
struct
GemmPipelineAGmemBGmemCRegV2
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
View file @
bea1a130
...
@@ -7,12 +7,11 @@
...
@@ -7,12 +7,11 @@
namespace
ck_tile
{
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV2
// Default policy for GemmPipelineAGmemBGmemCRegV2
// Default policy class should not be templated, put template on member functions instead
// Default policy class should not be templated, put template on member functions instead
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
using
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
bea1a130
...
@@ -13,20 +13,23 @@ template <typename ADataType_,
...
@@ -13,20 +13,23 @@ template <typename ADataType_,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
bool
kPadA_
=
false
,
typename
TileGemmTraits_
>
bool
kPadB_
=
false
,
struct
GemmPipelineProblem
bool
kPadC_
=
false
>
struct
BlockGemmPipelineProblem
{
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
using
LayoutA
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutC
>
;
static
constexpr
index_t
AlignmentA
=
kPadA
?
1
:
VectorLoadSize
/
sizeof
(
ADataType
);
static
constexpr
index_t
AlignmentA
=
kPadA
?
1
:
VectorLoadSize
/
sizeof
(
ADataType
);
static
constexpr
index_t
AlignmentB
=
kPadB
?
1
:
VectorLoadSize
/
sizeof
(
BDataType
);
static
constexpr
index_t
AlignmentB
=
kPadB
?
1
:
VectorLoadSize
/
sizeof
(
BDataType
);
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
0 → 100644
View file @
bea1a130
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
bool
kPadA_
,
bool
kPadB_
,
bool
kPadC_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
TileGemmTraits
{
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
using
LayoutA
=
LayoutA_
;
using
LayoutB
=
LayoutB_
;
using
LayoutC
=
LayoutC_
;
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment