Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4085e3d0
Commit
4085e3d0
authored
Oct 10, 2024
by
Adam Osewski
Browse files
Merge branch 'develop' into aosewski/ck_tile_universal_gemm_p1
parents
1b2bf88d
6f27bc98
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
440 additions
and
158 deletions
+440
-158
CMakeLists.txt
CMakeLists.txt
+6
-2
Jenkinsfile
Jenkinsfile
+1
-1
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+37
-18
include/ck_tile/core/container/thread_buffer.hpp
include/ck_tile/core/container/thread_buffer.hpp
+1
-1
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+27
-7
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-0
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+171
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+82
-51
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+50
-31
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+6
-5
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+3
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+3
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
+4
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+17
-27
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+27
-0
No files found.
CMakeLists.txt
View file @
4085e3d0
...
...
@@ -132,7 +132,11 @@ if(GPU_ARCHS)
unset
(
GPU_TARGETS CACHE
)
unset
(
AMDGPU_TARGETS CACHE
)
endif
()
if
(
GPU_TARGETS
)
set
(
USER_GPU_TARGETS 1
)
else
()
set
(
USER_GPU_TARGETS 0
)
endif
()
find_package
(
hip
)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
...
...
@@ -162,7 +166,7 @@ endif()
if
(
GPU_ARCHS
)
set
(
CK_GPU_TARGETS
${
GPU_ARCHS
}
)
else
()
if
(
GPU_TARGETS
)
if
(
USER_
GPU_TARGETS
)
set
(
CK_GPU_TARGETS
${
GPU_TARGETS
}
)
endif
()
endif
()
...
...
Jenkinsfile
View file @
4085e3d0
...
...
@@ -1138,7 +1138,7 @@ pipeline {
execute_args
=
""" cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102
;gfx1200;gfx1201
" \
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
}
steps
{
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
4085e3d0
...
...
@@ -19,9 +19,12 @@ template <typename ALayout, typename BLayout, typename CLayout>
float
gemm_calc
(
const
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadC
=
true
;
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadC
=
true
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
...
...
@@ -38,27 +41,43 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
CodegenPipelineProblem
=
ck_tile
::
BlockGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
ALayout
,
BLayout
,
CLayout
,
kPadA
,
kPadB
,
kPadC
>
;
using
CodegenGemmPipeline
=
ck_tile
::
BlockGemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
...
include/ck_tile/core/container/thread_buffer.hpp
View file @
4085e3d0
...
...
@@ -58,7 +58,7 @@ struct thread_buffer {
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
()
const
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
at
(
number
<
I
>
)
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
(
number
<
I
>
)
const
{
return
get
(
I
);
}
template
<
typename
X_
,
typename
std
::
enable_if
<
has_same_scalar_type
<
value_type
,
X_
>
::
value
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
_get_as
()
const
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
4085e3d0
...
...
@@ -47,7 +47,13 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
make_ParallelTensorFunctor
(
f_mn
,
M
,
N
)(
std
::
thread
::
hardware_concurrency
());
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
BDataType
*
B
,
CDataType
*
C
,
...
...
@@ -65,18 +71,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if
(
row
<
M
&&
col
<
N
)
{
AccDataType
acc
=
0.0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
acc
+=
static_cast
<
AccDataType
>
(
A
[
row
*
strideA
+
k
])
*
static_cast
<
AccDataType
>
(
B
[
col
*
strideB
+
k
]);
// Adjust indexing based on matrix layout
int
a_index
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideA
+
k
:
k
*
strideA
+
row
;
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
col
*
strideB
+
k
:
k
*
strideB
+
col
;
acc
+=
static_cast
<
AccDataType
>
(
A
[
a_index
])
*
static_cast
<
AccDataType
>
(
B
[
b_index
]);
}
C
[
row
*
strideC
+
col
]
=
acc
;
// Store as AccDataType
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideC
+
col
:
col
*
strideC
+
row
;
C
[
c_index
]
=
acc
;
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
...
...
@@ -134,7 +154,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
...
...
include/ck_tile/ops/epilogue.hpp
View file @
4085e3d0
...
...
@@ -3,5 +3,6 @@
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
0 → 100644
View file @
4085e3d0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace
ck_tile
{
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kTilePermute_
,
index_t
kRank_
,
index_t
kPerm0
,
index_t
kPerm1
,
index_t
TileSize0
,
index_t
TileSize1
,
index_t
kPerm2
=
0
,
index_t
kPerm3
=
0
,
index_t
kPerm4
=
0
,
index_t
TileSize2
=
0
,
index_t
TileSize3
=
0
,
index_t
TileSize4
=
0
>
struct
CShuffleEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTilePermute
=
kTilePermute_
;
static
constexpr
index_t
kRank
=
kRank_
;
static
constexpr
index_t
kPerm
[
CK_TILE_MAX_RANK
]
=
{
kPerm0
,
kPerm1
,
kPerm2
,
kPerm3
,
kPerm4
};
static
constexpr
index_t
tile_sizes
[
CK_TILE_MAX_RANK
]
=
{
TileSize0
,
TileSize1
,
TileSize2
,
TileSize3
,
TileSize4
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
CShuffleEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
const
index_t
*
kPerm
=
Problem
::
kPerm
;
static
constexpr
bool
kTilePermute
=
Problem
::
kTilePermute
;
static
constexpr
index_t
kRank
=
Problem
::
kRank
;
const
index_t
*
tile_sizes
=
Problem
::
tile_sizes
;
// No additional shared memory needed
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
template
<
typename
OAccTile
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
{
using
DataType
=
typename
OAccTile
::
DataType
;
// Get thread buffer
auto
&
thread_buf
=
o_acc_tile
.
get_thread_buffer
();
// Create a temporary buffer to hold the permuted data
thread_buffer
<
DataType
,
OAccTile
::
kThreadElementSpaceSize
>
permuted_thread_buf
;
// Get the lengths of each dimension
auto
thread_tensor_lengths
=
o_acc_tile
.
get_lengths
();
// Total number of elements
index_t
total_elements
=
OAccTile
::
kThreadElementSpaceSize
;
// Iterate over all elements
for
(
index_t
linear_idx
=
0
;
linear_idx
<
total_elements
;
++
linear_idx
)
{
// Convert linear index to multi-dimensional indices
array
<
index_t
,
kRank
>
indices
;
index_t
remaining
=
linear_idx
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
indices
(
rev_i
)
=
remaining
%
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
remaining
/=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Apply the permutation
array
<
index_t
,
kRank
>
permuted_indices
;
static_for
<
0
,
kRank
,
1
>
{}(
[
&
](
auto
i
)
{
permuted_indices
(
i
)
=
indices
.
get
(
number
<
Problem
::
kPerm
[
i
]
>
{});
});
// Compute offsets
index_t
dst_offset
=
0
;
index_t
stride
=
1
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
dst_offset
+=
permuted_indices
[
rev_i
]
*
stride
;
stride
*=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Move the data
permuted_thread_buf
(
dst_offset
)
=
thread_buf
[
linear_idx
];
}
// Copy the permuted data back to the original thread buffer
for
(
index_t
i
=
0
;
i
<
total_elements
;
++
i
)
{
thread_buf
.
set_as
(
i
,
permuted_thread_buf
.
get
(
i
));
}
}
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t
tile_coords
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
tile_coords
[
i
]
=
current_window_origin
[
i
]
/
tile_sizes
[
i
];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t
permuted_tile_coords
[
CK_TILE_MAX_RANK
];
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_tile_coords
[
i
]
=
tile_coords
[
kPerm
[
i
]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t
permuted_window_origin
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_window_origin
[
i
]
=
permuted_tile_coords
[
i
]
*
tile_sizes
[
i
];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename
ODramWindowTmp
::
BottomTensorIndex
step
=
{};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
step
[
i
]
=
permuted_window_origin
[
i
]
-
current_window_origin
[
i
];
}
// Move the window
move_tile_window
(
o_dram_window_tmp
,
step
);
// Permute the data within the tile if necessary
if
constexpr
(
kTilePermute
)
{
permute_tile_data
(
o_acc_tile
);
}
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
4085e3d0
...
...
@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QDataType
,
...
...
@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimV
,
Problem
::
kPadHeadDimV
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
OGradDataType
,
...
...
@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
// these are for global load
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
4085e3d0
...
...
@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
...
@@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
...
...
@@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
...
@@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBSmemCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmASmemBSmemCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
...
...
@@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
@@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
typename
Problem
::
OaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
...
...
include/ck_tile/ops/gemm.hpp
View file @
4085e3d0
...
...
@@ -23,14 +23,15 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
4085e3d0
...
...
@@ -177,12 +177,12 @@ struct GemmKernel
// clang-format off
sequence
<
false
,
GemmPipeline
::
kPadC
>
{});
// clang-format on
auto
CB
lock
W
indow
=
make_tile_window
(
auto
c_b
lock
_w
indow
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
EpiloguePipeline
{}(
CB
lock
W
indow
,
c_block_tile
);
EpiloguePipeline
{}(
c_b
lock
_w
indow
,
c_block_tile
);
}
};
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
4085e3d0
...
...
@@ -4,15 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
Block
GemmPipelineAGmemBGmemCRegV1
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
GemmPipelineAGmemBGmemCRegV1
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
4085e3d0
...
...
@@ -7,9 +7,9 @@
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV1
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
#if 0
// 2d
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
4085e3d0
...
...
@@ -4,15 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
struct
Block
GemmPipelineAGmemBGmemCRegV2
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
struct
GemmPipelineAGmemBGmemCRegV2
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
View file @
4085e3d0
...
...
@@ -7,12 +7,11 @@
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV2
// Default policy for GemmPipelineAGmemBGmemCRegV2
// Default policy class should not be templated, put template on member functions instead
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
using
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
4085e3d0
...
...
@@ -13,27 +13,23 @@ template <typename ADataType_,
typename
BDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
ALayout_
,
typename
BLayout_
,
typename
CLayout_
,
bool
kPadA_
=
false
,
bool
kPadB_
=
false
,
bool
kPadC_
=
false
>
struct
BlockGemmPipelineProblem
typename
TileGemmTraits_
>
struct
GemmPipelineProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
ALayout
_
>
;
using
BLayout
=
remove_cvref_t
<
BLayout
_
>
;
using
CLayout
=
remove_cvref_t
<
CLayout
_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA
_
;
static
constexpr
bool
kPadB
=
kPadB
_
;
static
constexpr
bool
kPadC
=
kPadC
_
;
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
static
constexpr
index_t
AlignmentA
=
kPadA
?
1
:
VectorLoadSize
/
sizeof
(
ADataType
);
static
constexpr
index_t
AlignmentB
=
kPadB
?
1
:
VectorLoadSize
/
sizeof
(
BDataType
);
...
...
@@ -42,15 +38,9 @@ struct BlockGemmPipelineProblem
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
AccDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
ALayout_
,
typename
BLayout_
,
typename
CLayout_
,
bool
kPadA_
=
false
,
bool
kPadB_
=
false
,
bool
kPadC_
=
false
,
typename
TileGemmTraits_
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
bool
HasHotLoop_
=
true
,
TailNumber
TailNum_
=
TailNumber
::
Full
>
...
...
@@ -58,22 +48,22 @@ struct UniversalGemmPipelineProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
ALayout
_
>
;
using
BLayout
=
remove_cvref_t
<
BLayout
_
>
;
using
CLayout
=
remove_cvref_t
<
CLayout
_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA
_
;
static
constexpr
bool
kPadB
=
kPadB
_
;
static
constexpr
bool
kPadC
=
kPadC
_
;
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
// TODO: what about vector load/store size? should we have template paramter for A/B/C ?
static
constexpr
index_t
AlignmentA
=
kPadA
?
VectorLoadSize
/
sizeof
(
ADataType
)
:
1
;
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
0 → 100644
View file @
4085e3d0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
bool
kPadA_
,
bool
kPadB_
,
bool
kPadC_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
TileGemmTraits
{
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
using
LayoutA
=
LayoutA_
;
using
LayoutB
=
LayoutB_
;
using
LayoutC
=
LayoutC_
;
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment