Commit 74bc4a95 authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'develop' into letaoqin/gemm_bias_activation

parents 3b065199 29d384d0
......@@ -132,7 +132,11 @@ if(GPU_ARCHS)
unset(GPU_TARGETS CACHE)
unset(AMDGPU_TARGETS CACHE)
endif()
if(GPU_TARGETS)
set(USER_GPU_TARGETS 1)
else()
set(USER_GPU_TARGETS 0)
endif()
find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
......@@ -162,7 +166,7 @@ endif()
if(GPU_ARCHS)
set(CK_GPU_TARGETS ${GPU_ARCHS})
else()
if(GPU_TARGETS)
if(USER_GPU_TARGETS)
set(CK_GPU_TARGETS ${GPU_TARGETS})
endif()
endif()
......@@ -545,7 +549,7 @@ ENDFOREACH()
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
add_subdirectory(library)
if(NOT GPU_ARCHS)
if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
rocm_package_setup_component(tests
LIBRARY_NAME composablekernel
PACKAGE_NAME tests # Prevent -static suffix on package name
......
......@@ -353,7 +353,7 @@ def buildHipClangJob(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
......@@ -412,7 +412,7 @@ def runCKProfiler(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
......@@ -544,7 +544,7 @@ def Build_CK(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
......@@ -660,7 +660,7 @@ def process_results(Map conf=[:]){
def prefixpath = "/opt/rocm"
// Jenkins is complaining about the render group
def dockerOpts="--rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
......@@ -1138,7 +1138,7 @@ pipeline {
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" \
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
}
steps{
......
......@@ -91,6 +91,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets
supported by the current compiler (this may take a long time).
Tests and examples will only get built if the GPU_TARGETS is set by the user on the cmake command line.
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the
architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you
......
......@@ -12,12 +12,6 @@ API reference guide
This document contains details of the APIs for the Composable Kernel (CK) library and introduces
some of the key design principles that are used to write new classes that extend CK functionality.
=================
Using CK API
=================
This section describes how to use the CK library API.
=================
CK Datatypes
=================
......
......@@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
if(stride == 0)
{
// give a chance if stride is -1, return a default packed stride
// give a chance if stride is 0, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
......
......@@ -43,16 +43,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kTilePermute = false;
constexpr int kBlockPerCu = 1;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadA,
kPadB,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel =
ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
......@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
kPadA,
kPadB,
kPadC>;
using CodegenGemmTraits = ck_tile::
TileGemmTraits<kPadA, kPadB, kPadC, matrix_a_layout, matrix_b_layout, matrix_c_layout>;
using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
invoke_gemm<ck_tile::half_t,
matrix_a_layout,
......@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout>(
a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
c_buf.FromDevice(c_host_gpu_ref.data());
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "device_base.hpp"
......@@ -37,7 +37,7 @@ struct DeviceCGemm : public BaseOperator
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC) = 0;
index_t StrideC) const = 0;
};
template <typename AElementwiseOperation,
......
......@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[maybe_unused]] index_t K,
[[maybe_unused]] index_t StrideA,
[[maybe_unused]] index_t StrideB,
index_t StrideC) override
index_t StrideC) const override
{
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
}
std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
{
const auto* parg = dynamic_cast<const Argument*>(base_arg);
if(!parg)
{
std::ostringstream err;
err << "Provided argument pointer is not of an Argument class!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return GetWorkspaceSize(
parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
}
};
} // namespace device
......
......@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_n_k.mDesc.get_lengths()[0]
: b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0];
......@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k));
BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_element_op(b_n_k(n, k))
: b_element_op(b_n_k(k, n));
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? c_m_n(m, n)
: c_m_n(n, m);
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
}
};
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
BDataType* B,
CDataType* C,
......@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if(row < M && col < N)
{
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
acc += static_cast<AccDataType>(A[row * strideA + k]) *
static_cast<AccDataType>(B[col * strideB + k]);
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
}
C[row * strideC + col] = acc; // Store as AccDataType
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = acc;
}
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device,
DeviceMem& c_device,
......@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType>
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
......
......@@ -3,5 +3,6 @@
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace ck_tile {
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool kTilePermute_,
index_t kRank_,
index_t kPerm0,
index_t kPerm1,
index_t TileSize0,
index_t TileSize1,
index_t kPerm2 = 0,
index_t kPerm3 = 0,
index_t kPerm4 = 0,
index_t TileSize2 = 0,
index_t TileSize3 = 0,
index_t TileSize4 = 0>
struct CShuffleEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTilePermute = kTilePermute_;
static constexpr index_t kRank = kRank_;
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
const index_t* kPerm = Problem::kPerm;
static constexpr bool kTilePermute = Problem::kTilePermute;
static constexpr index_t kRank = Problem::kRank;
const index_t* tile_sizes = Problem::tile_sizes;
// No additional shared memory needed
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename OAccTile>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
{
using DataType = typename OAccTile::DataType;
// Get thread buffer
auto& thread_buf = o_acc_tile.get_thread_buffer();
// Create a temporary buffer to hold the permuted data
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
// Get the lengths of each dimension
auto thread_tensor_lengths = o_acc_tile.get_lengths();
// Total number of elements
index_t total_elements = OAccTile::kThreadElementSpaceSize;
// Iterate over all elements
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
{
// Convert linear index to multi-dimensional indices
array<index_t, kRank> indices;
index_t remaining = linear_idx;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
remaining /= thread_tensor_lengths.get(number<rev_i>{});
});
// Apply the permutation
array<index_t, kRank> permuted_indices;
static_for<0, kRank, 1>{}(
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
// Compute offsets
index_t dst_offset = 0;
index_t stride = 1;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
dst_offset += permuted_indices[rev_i] * stride;
stride *= thread_tensor_lengths.get(number<rev_i>{});
});
// Move the data
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
}
// Copy the permuted data back to the original thread buffer
for(index_t i = 0; i < total_elements; ++i)
{
thread_buf.set_as(i, permuted_thread_buf.get(i));
}
}
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
{
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
for(index_t i = 0; i < kRank; ++i)
{
permuted_tile_coords[i] = tile_coords[kPerm[i]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename ODramWindowTmp::BottomTensorIndex step = {};
for(index_t i = 0; i < kRank; ++i)
{
step[i] = permuted_window_origin[i] - current_window_origin[i];
}
// Move the window
move_tile_window(o_dram_window_tmp, step);
// Permute the data within the tile if necessary
if constexpr(kTilePermute)
{
permute_tile_data(o_acc_tile);
}
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
};
} // namespace ck_tile
......@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
......@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
......@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimV,
Problem::kPadHeadDimV,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::OGradDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
typename Problem::BlockFmhaShape::Gemm2WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
......@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
typename Problem::BlockFmhaShape::Gemm3WarpTile>,
TileGemmTraits<Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
typename Problem::BlockFmhaShape::Gemm4WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
// these are for global load
......
......@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
......@@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
}
};
......@@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmASmemBSmemCRegV1<GemmProblem, BlockGemmPolicy>{};
}
};
......@@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::PDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......@@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
}
};
......
......@@ -23,12 +23,13 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
......
......@@ -11,20 +11,12 @@
namespace ck_tile {
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
typename LayoutA_,
typename LayoutB_,
typename LayoutC_>
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
......@@ -32,6 +24,10 @@ struct GemmKernel
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using LayoutA = remove_cvref_t<typename GemmPipeline::LayoutA>;
using LayoutB = remove_cvref_t<typename GemmPipeline::LayoutB>;
using LayoutC = remove_cvref_t<typename GemmPipeline::LayoutC>;
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
{
return TilePartitioner::GridSize(M_size, N_size, Batch_size);
......@@ -184,6 +180,7 @@ struct GemmKernel
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, acc);
}
};
......
......@@ -4,15 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct BlockGemmPipelineAGmemBGmemCRegV1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
......@@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC;
using LayoutA = remove_cvref_t<typename Problem::LayoutA>;
using LayoutB = remove_cvref_t<typename Problem::LayoutB>;
using LayoutC = remove_cvref_t<typename Problem::LayoutC>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
return ck_tile::integer_divide_ceil(
......
......@@ -7,9 +7,9 @@
namespace ck_tile {
// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
#if 0
// 2d
......
......@@ -4,15 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct BlockGemmPipelineAGmemBGmemCRegV2
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV2
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
......
......@@ -7,12 +7,11 @@
namespace ck_tile {
// Default policy for BlockGemmPipelineAGmemBGmemCRegV2
// Default policy for GemmPipelineAGmemBGmemCRegV2
// Default policy class should not be templated, put template on member functions instead
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy =
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy;
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment