Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
74bc4a95
"git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "1e4226634c253d84ead3017da663dedd3847309f"
Commit
74bc4a95
authored
Oct 12, 2024
by
letaoqin
Browse files
Merge branch 'develop' into letaoqin/gemm_bias_activation
parents
3b065199
29d384d0
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
445 additions
and
155 deletions
+445
-155
CMakeLists.txt
CMakeLists.txt
+7
-3
Jenkinsfile
Jenkinsfile
+5
-5
README.md
README.md
+1
-0
docs/reference/API_Reference_Guide.rst
docs/reference/API_Reference_Guide.rst
+0
-6
example/01_gemm/run_gemm_example_streamk_v2.inc
example/01_gemm/run_gemm_example_streamk_v2.inc
+2
-2
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+40
-15
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+17
-1
include/ck_tile/core/container/thread_buffer.hpp
include/ck_tile/core/container/thread_buffer.hpp
+1
-1
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+37
-10
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-0
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+171
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+82
-51
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+50
-31
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+6
-5
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+6
-9
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+7
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+3
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
+4
-5
No files found.
CMakeLists.txt
View file @
74bc4a95
...
@@ -132,7 +132,11 @@ if(GPU_ARCHS)
...
@@ -132,7 +132,11 @@ if(GPU_ARCHS)
unset
(
GPU_TARGETS CACHE
)
unset
(
GPU_TARGETS CACHE
)
unset
(
AMDGPU_TARGETS CACHE
)
unset
(
AMDGPU_TARGETS CACHE
)
endif
()
endif
()
if
(
GPU_TARGETS
)
set
(
USER_GPU_TARGETS 1
)
else
()
set
(
USER_GPU_TARGETS 0
)
endif
()
find_package
(
hip
)
find_package
(
hip
)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
# SWDEV-413293 and https://reviews.llvm.org/D155213
...
@@ -162,7 +166,7 @@ endif()
...
@@ -162,7 +166,7 @@ endif()
if
(
GPU_ARCHS
)
if
(
GPU_ARCHS
)
set
(
CK_GPU_TARGETS
${
GPU_ARCHS
}
)
set
(
CK_GPU_TARGETS
${
GPU_ARCHS
}
)
else
()
else
()
if
(
GPU_TARGETS
)
if
(
USER_
GPU_TARGETS
)
set
(
CK_GPU_TARGETS
${
GPU_TARGETS
}
)
set
(
CK_GPU_TARGETS
${
GPU_TARGETS
}
)
endif
()
endif
()
endif
()
endif
()
...
@@ -545,7 +549,7 @@ ENDFOREACH()
...
@@ -545,7 +549,7 @@ ENDFOREACH()
add_custom_target
(
instances DEPENDS utility;
${
CK_DEVICE_INSTANCES
}
SOURCES
${
INSTANCE_FILES
}
)
add_custom_target
(
instances DEPENDS utility;
${
CK_DEVICE_INSTANCES
}
SOURCES
${
INSTANCE_FILES
}
)
add_subdirectory
(
library
)
add_subdirectory
(
library
)
if
(
NOT GPU_ARCHS
)
if
(
NOT GPU_ARCHS
AND USER_GPU_TARGETS
)
rocm_package_setup_component
(
tests
rocm_package_setup_component
(
tests
LIBRARY_NAME composablekernel
LIBRARY_NAME composablekernel
PACKAGE_NAME tests
# Prevent -static suffix on package name
PACKAGE_NAME tests
# Prevent -static suffix on package name
...
...
Jenkinsfile
View file @
74bc4a95
...
@@ -353,7 +353,7 @@ def buildHipClangJob(Map conf=[:]){
...
@@ -353,7 +353,7 @@ def buildHipClangJob(Map conf=[:]){
def
prefixpath
=
conf
.
get
(
"prefixpath"
,
"/opt/rocm"
)
def
prefixpath
=
conf
.
get
(
"prefixpath"
,
"/opt/rocm"
)
// Jenkins is complaining about the render group
// Jenkins is complaining about the render group
def
dockerOpts
=
"
--rm
--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def
dockerOpts
=
"--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1 "
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1 "
}
}
...
@@ -412,7 +412,7 @@ def runCKProfiler(Map conf=[:]){
...
@@ -412,7 +412,7 @@ def runCKProfiler(Map conf=[:]){
def
prefixpath
=
conf
.
get
(
"prefixpath"
,
"/opt/rocm"
)
def
prefixpath
=
conf
.
get
(
"prefixpath"
,
"/opt/rocm"
)
// Jenkins is complaining about the render group
// Jenkins is complaining about the render group
def
dockerOpts
=
"
--rm
--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def
dockerOpts
=
"--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1 "
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1 "
}
}
...
@@ -544,7 +544,7 @@ def Build_CK(Map conf=[:]){
...
@@ -544,7 +544,7 @@ def Build_CK(Map conf=[:]){
def
prefixpath
=
conf
.
get
(
"prefixpath"
,
"/opt/rocm"
)
def
prefixpath
=
conf
.
get
(
"prefixpath"
,
"/opt/rocm"
)
// Jenkins is complaining about the render group
// Jenkins is complaining about the render group
def
dockerOpts
=
"
--rm
--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def
dockerOpts
=
"--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1 "
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1 "
}
}
...
@@ -660,7 +660,7 @@ def process_results(Map conf=[:]){
...
@@ -660,7 +660,7 @@ def process_results(Map conf=[:]){
def
prefixpath
=
"/opt/rocm"
def
prefixpath
=
"/opt/rocm"
// Jenkins is complaining about the render group
// Jenkins is complaining about the render group
def
dockerOpts
=
"
--rm
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def
dockerOpts
=
"--cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1 "
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1 "
}
}
...
@@ -1138,7 +1138,7 @@ pipeline {
...
@@ -1138,7 +1138,7 @@ pipeline {
execute_args
=
""" cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
execute_args
=
""" cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102
;gfx1200;gfx1201
" \
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
}
}
steps
{
steps
{
...
...
README.md
View file @
74bc4a95
...
@@ -91,6 +91,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
...
@@ -91,6 +91,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets
If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets
supported by the current compiler (this may take a long time).
supported by the current compiler (this may take a long time).
Tests and examples will only get built if the GPU_TARGETS is set by the user on the cmake command line.
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the
architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you
architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you
...
...
docs/reference/API_Reference_Guide.rst
View file @
74bc4a95
...
@@ -12,12 +12,6 @@ API reference guide
...
@@ -12,12 +12,6 @@ API reference guide
This document contains details of the APIs for the Composable Kernel (CK) library and introduces
This document contains details of the APIs for the Composable Kernel (CK) library and introduces
some of the key design principles that are used to write new classes that extend CK functionality.
some of the key design principles that are used to write new classes that extend CK functionality.
=================
Using CK API
=================
This section describes how to use the CK library API.
=================
=================
CK Datatypes
CK Datatypes
=================
=================
...
...
example/01_gemm/run_gemm_example_streamk_v2.inc
View file @
74bc4a95
...
@@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto
f_get_default_stride
=
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
ck
::
index_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
ck
::
index_t
stride
,
auto
layout
)
{
if
(
stride
==
-
1
)
if
(
stride
==
0
)
{
{
// give a chance if stride is
-1
, return a default packed stride
// give a chance if stride is
0
, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
static_cast
<
std
::
size_t
>
(
col
);
return
static_cast
<
std
::
size_t
>
(
col
);
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
74bc4a95
...
@@ -43,16 +43,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -43,16 +43,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kTilePermute
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
LayoutC
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
,
LayoutA
,
LayoutB
,
LayoutC
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_b
,
...
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
...
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
CodegenPipelineProblem
=
ck_tile
::
BlockGemmPipelineProblem
<
ADataType
,
using
CodegenGemmTraits
=
ck_tile
::
BDataType
,
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
AccDataType
,
CodegenGemmShape
,
kPadA
,
kPadB
,
kPadC
>
;
using
CodegenGemmPipeline
=
ck_tile
::
BlockGemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
invoke_gemm
<
ck_tile
::
half_t
,
invoke_gemm
<
ck_tile
::
half_t
,
matrix_a_layout
,
matrix_a_layout
,
...
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
...
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
(
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
...
...
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
View file @
74bc4a95
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "device_base.hpp"
#include "device_base.hpp"
...
@@ -37,7 +37,7 @@ struct DeviceCGemm : public BaseOperator
...
@@ -37,7 +37,7 @@ struct DeviceCGemm : public BaseOperator
index_t
KRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
)
=
0
;
index_t
StrideC
)
const
=
0
;
};
};
template
<
typename
AElementwiseOperation
,
template
<
typename
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
74bc4a95
...
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[
maybe_unused
]]
index_t
K
,
[[
maybe_unused
]]
index_t
K
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideB
,
[[
maybe_unused
]]
index_t
StrideB
,
index_t
StrideC
)
override
index_t
StrideC
)
const
override
{
{
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
}
}
std
::
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
base_arg
)
const
override
{
const
auto
*
parg
=
dynamic_cast
<
const
Argument
*>
(
base_arg
);
if
(
!
parg
)
{
std
::
ostringstream
err
;
err
<<
"Provided argument pointer is not of an Argument class!"
<<
" In "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
return
GetWorkspaceSize
(
parg
->
M
,
parg
->
N
,
parg
->
K
,
parg
->
StrideA
,
parg
->
StrideB
,
parg
->
StrideC
);
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck_tile/core/container/thread_buffer.hpp
View file @
74bc4a95
include/ck_tile/host/reference/reference_gemm.hpp
View file @
74bc4a95
...
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const
BElementOp
&
b_element_op
=
{},
const
BElementOp
&
b_element_op
=
{},
const
ACCElementOp
&
acc_element_op
=
{})
const
ACCElementOp
&
acc_element_op
=
{})
{
{
const
int
N
=
b_n_k
.
mDesc
.
get_lengths
()[
0
];
const
int
N
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
b_n_k
.
mDesc
.
get_lengths
()[
0
]
:
b_n_k
.
mDesc
.
get_lengths
()[
1
];
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
...
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_element_op
(
a_m_k
(
m
,
k
))
?
a_element_op
(
a_m_k
(
m
,
k
))
:
a_element_op
(
a_m_k
(
k
,
m
));
:
a_element_op
(
a_m_k
(
k
,
m
));
BDataType
v_b
=
b_element_op
(
b_n_k
(
n
,
k
));
BDataType
v_b
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
b_element_op
(
b_n_k
(
n
,
k
))
:
b_element_op
(
b_n_k
(
k
,
n
));
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
c_m_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
CDataType
&
c_ref
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
c_m_n
(
m
,
n
)
:
c_m_n
(
n
,
m
);
c_ref
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
}
}
};
};
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
BDataType
*
B
,
BDataType
*
B
,
CDataType
*
C
,
CDataType
*
C
,
...
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
...
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if
(
row
<
M
&&
col
<
N
)
if
(
row
<
M
&&
col
<
N
)
{
{
AccDataType
acc
=
0.0
;
AccDataType
acc
=
0.0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
acc
+=
static_cast
<
AccDataType
>
(
A
[
row
*
strideA
+
k
])
*
// Adjust indexing based on matrix layout
static_cast
<
AccDataType
>
(
B
[
col
*
strideB
+
k
]);
int
a_index
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideA
+
k
:
k
*
strideA
+
row
;
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
col
*
strideB
+
k
:
k
*
strideB
+
col
;
acc
+=
static_cast
<
AccDataType
>
(
A
[
a_index
])
*
static_cast
<
AccDataType
>
(
B
[
b_index
]);
}
}
C
[
row
*
strideC
+
col
]
=
acc
;
// Store as AccDataType
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideC
+
col
:
col
*
strideC
+
row
;
C
[
c_index
]
=
acc
;
}
}
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
DeviceMem
&
c_device
,
...
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
...
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
errC
=
hipMemcpy
(
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
...
...
include/ck_tile/ops/epilogue.hpp
View file @
74bc4a95
...
@@ -3,5 +3,6 @@
...
@@ -3,5 +3,6 @@
#pragma once
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
0 → 100644
View file @
74bc4a95
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace
ck_tile
{
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kTilePermute_
,
index_t
kRank_
,
index_t
kPerm0
,
index_t
kPerm1
,
index_t
TileSize0
,
index_t
TileSize1
,
index_t
kPerm2
=
0
,
index_t
kPerm3
=
0
,
index_t
kPerm4
=
0
,
index_t
TileSize2
=
0
,
index_t
TileSize3
=
0
,
index_t
TileSize4
=
0
>
struct
CShuffleEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTilePermute
=
kTilePermute_
;
static
constexpr
index_t
kRank
=
kRank_
;
static
constexpr
index_t
kPerm
[
CK_TILE_MAX_RANK
]
=
{
kPerm0
,
kPerm1
,
kPerm2
,
kPerm3
,
kPerm4
};
static
constexpr
index_t
tile_sizes
[
CK_TILE_MAX_RANK
]
=
{
TileSize0
,
TileSize1
,
TileSize2
,
TileSize3
,
TileSize4
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
CShuffleEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
const
index_t
*
kPerm
=
Problem
::
kPerm
;
static
constexpr
bool
kTilePermute
=
Problem
::
kTilePermute
;
static
constexpr
index_t
kRank
=
Problem
::
kRank
;
const
index_t
*
tile_sizes
=
Problem
::
tile_sizes
;
// No additional shared memory needed
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
template
<
typename
OAccTile
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
{
using
DataType
=
typename
OAccTile
::
DataType
;
// Get thread buffer
auto
&
thread_buf
=
o_acc_tile
.
get_thread_buffer
();
// Create a temporary buffer to hold the permuted data
thread_buffer
<
DataType
,
OAccTile
::
kThreadElementSpaceSize
>
permuted_thread_buf
;
// Get the lengths of each dimension
auto
thread_tensor_lengths
=
o_acc_tile
.
get_lengths
();
// Total number of elements
index_t
total_elements
=
OAccTile
::
kThreadElementSpaceSize
;
// Iterate over all elements
for
(
index_t
linear_idx
=
0
;
linear_idx
<
total_elements
;
++
linear_idx
)
{
// Convert linear index to multi-dimensional indices
array
<
index_t
,
kRank
>
indices
;
index_t
remaining
=
linear_idx
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
indices
(
rev_i
)
=
remaining
%
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
remaining
/=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Apply the permutation
array
<
index_t
,
kRank
>
permuted_indices
;
static_for
<
0
,
kRank
,
1
>
{}(
[
&
](
auto
i
)
{
permuted_indices
(
i
)
=
indices
.
get
(
number
<
Problem
::
kPerm
[
i
]
>
{});
});
// Compute offsets
index_t
dst_offset
=
0
;
index_t
stride
=
1
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
dst_offset
+=
permuted_indices
[
rev_i
]
*
stride
;
stride
*=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Move the data
permuted_thread_buf
(
dst_offset
)
=
thread_buf
[
linear_idx
];
}
// Copy the permuted data back to the original thread buffer
for
(
index_t
i
=
0
;
i
<
total_elements
;
++
i
)
{
thread_buf
.
set_as
(
i
,
permuted_thread_buf
.
get
(
i
));
}
}
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t
tile_coords
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
tile_coords
[
i
]
=
current_window_origin
[
i
]
/
tile_sizes
[
i
];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t
permuted_tile_coords
[
CK_TILE_MAX_RANK
];
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_tile_coords
[
i
]
=
tile_coords
[
kPerm
[
i
]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t
permuted_window_origin
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_window_origin
[
i
]
=
permuted_tile_coords
[
i
]
*
tile_sizes
[
i
];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename
ODramWindowTmp
::
BottomTensorIndex
step
=
{};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
step
[
i
]
=
permuted_window_origin
[
i
]
-
current_window_origin
[
i
];
}
// Move the window
move_tile_window
(
o_dram_window_tmp
,
step
);
// Permute the data within the tile if necessary
if
constexpr
(
kTilePermute
)
{
permute_tile_data
(
o_acc_tile
);
}
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
74bc4a95
...
@@ -5,8 +5,9 @@
...
@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...
@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
QDataType
,
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QDataType
,
typename
Problem
::
QDataType
,
...
@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
GemmDataType
,
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimV
,
Problem
::
kPadHeadDimV
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
OGradDataType
,
GemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
OGradDataType
,
...
@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
GemmDataType
,
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
GemmDataType
,
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
// these are for global load
// these are for global load
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
74bc4a95
...
@@ -5,8 +5,9 @@
...
@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...
@@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
QDataType
,
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
@@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
@@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
QDataType
,
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
@@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBSmemCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmASmemBSmemCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
@@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
using
GemmProblem
=
typename
Problem
::
PDataType
,
GemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
typename
Problem
::
OaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
auto
warp_gemm
=
[
&
]()
{
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
@@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
typename
Problem
::
OaccDataType
,
typename
Problem
::
OaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
...
include/ck_tile/ops/gemm.hpp
View file @
74bc4a95
...
@@ -23,12 +23,13 @@
...
@@ -23,12 +23,13 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
74bc4a95
...
@@ -11,20 +11,12 @@
...
@@ -11,20 +11,12 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
typename
GemmPipeline_
,
typename
EpiloguePipeline_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
GemmKernel
struct
GemmKernel
{
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
kBlockSize
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
kBlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
...
@@ -32,6 +24,10 @@ struct GemmKernel
...
@@ -32,6 +24,10 @@ struct GemmKernel
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
LayoutA
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutC
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M_size
,
index_t
N_size
,
index_t
Batch_size
)
__host__
static
constexpr
auto
GridSize
(
index_t
M_size
,
index_t
N_size
,
index_t
Batch_size
)
{
{
return
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
return
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
...
@@ -184,6 +180,7 @@ struct GemmKernel
...
@@ -184,6 +180,7 @@ struct GemmKernel
c_pad_view
,
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
{
i_m
,
i_n
});
EpiloguePipeline
{}(
CBlockWindow_pad
,
acc
);
EpiloguePipeline
{}(
CBlockWindow_pad
,
acc
);
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
74bc4a95
...
@@ -4,15 +4,15 @@
...
@@ -4,15 +4,15 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
// A Tile Window: global memory
// A Tile Window: global memory
// B Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
Block
GemmPipelineAGmemBGmemCRegV1
struct
GemmPipelineAGmemBGmemCRegV1
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
@@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
...
@@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static
constexpr
bool
kPadB
=
Problem
::
kPadB
;
static
constexpr
bool
kPadB
=
Problem
::
kPadB
;
static
constexpr
bool
kPadC
=
Problem
::
kPadC
;
static
constexpr
bool
kPadC
=
Problem
::
kPadC
;
using
LayoutA
=
remove_cvref_t
<
typename
Problem
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
Problem
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
Problem
::
LayoutC
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
{
{
return
ck_tile
::
integer_divide_ceil
(
return
ck_tile
::
integer_divide_ceil
(
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
74bc4a95
...
@@ -7,9 +7,9 @@
...
@@ -7,9 +7,9 @@
namespace
ck_tile
{
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV1
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
// Default policy class should not be templated, put template on member functions instead
struct
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
{
#if 0
#if 0
// 2d
// 2d
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
74bc4a95
...
@@ -4,15 +4,15 @@
...
@@ -4,15 +4,15 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
// A Tile Window: global memory
// A Tile Window: global memory
// B Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
struct
Block
GemmPipelineAGmemBGmemCRegV2
struct
GemmPipelineAGmemBGmemCRegV2
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
View file @
74bc4a95
...
@@ -7,12 +7,11 @@
...
@@ -7,12 +7,11 @@
namespace
ck_tile
{
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV2
// Default policy for GemmPipelineAGmemBGmemCRegV2
// Default policy class should not be templated, put template on member functions instead
// Default policy class should not be templated, put template on member functions instead
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
using
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
}
// namespace ck_tile
}
// namespace ck_tile
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment