Commit e547c141 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 467b4e50 4cf70b36
...@@ -549,7 +549,7 @@ ENDFOREACH() ...@@ -549,7 +549,7 @@ ENDFOREACH()
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
add_subdirectory(library) add_subdirectory(library)
if(NOT GPU_ARCHS) if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
rocm_package_setup_component(tests rocm_package_setup_component(tests
LIBRARY_NAME composablekernel LIBRARY_NAME composablekernel
PACKAGE_NAME tests # Prevent -static suffix on package name PACKAGE_NAME tests # Prevent -static suffix on package name
......
...@@ -353,7 +353,7 @@ def buildHipClangJob(Map conf=[:]){ ...@@ -353,7 +353,7 @@ def buildHipClangJob(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm") def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group // Jenkins is complaining about the render group
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) { if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 " dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
} }
...@@ -412,7 +412,7 @@ def runCKProfiler(Map conf=[:]){ ...@@ -412,7 +412,7 @@ def runCKProfiler(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm") def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group // Jenkins is complaining about the render group
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) { if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 " dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
} }
...@@ -544,7 +544,7 @@ def Build_CK(Map conf=[:]){ ...@@ -544,7 +544,7 @@ def Build_CK(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm") def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group // Jenkins is complaining about the render group
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) { if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 " dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
} }
...@@ -660,7 +660,7 @@ def process_results(Map conf=[:]){ ...@@ -660,7 +660,7 @@ def process_results(Map conf=[:]){
def prefixpath = "/opt/rocm" def prefixpath = "/opt/rocm"
// Jenkins is complaining about the render group // Jenkins is complaining about the render group
def dockerOpts="--rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) { if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 " dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
} }
......
...@@ -91,6 +91,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa ...@@ -91,6 +91,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets
supported by the current compiler (this may take a long time). supported by the current compiler (this may take a long time).
Tests and examples will only get built if the GPU_TARGETS is set by the user on the cmake command line.
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the
architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you
......
...@@ -12,12 +12,6 @@ API reference guide ...@@ -12,12 +12,6 @@ API reference guide
This document contains details of the APIs for the Composable Kernel (CK) library and introduces This document contains details of the APIs for the Composable Kernel (CK) library and introduces
some of the key design principles that are used to write new classes that extend CK functionality. some of the key design principles that are used to write new classes that extend CK functionality.
=================
Using CK API
=================
This section describes how to use the CK library API.
================= =================
CK Datatypes CK Datatypes
================= =================
......
...@@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto f_get_default_stride = auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1) if(stride == 0)
{ {
// give a chance if stride is -1, return a default packed stride // give a chance if stride is 0, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{ {
return static_cast<std::size_t>(col); return static_cast<std::size_t>(col);
......
...@@ -5,3 +5,4 @@ add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permu ...@@ -5,3 +5,4 @@ add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permu
add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp) add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp)
add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp) add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp)
add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp) add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp)
add_example_executable(elementwise_scale_permute_amax_2D_fp16_fp8 elementwise_scale_permute_amax_2D_fp16_fp8.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/utility/reduction_enums.hpp"
using F16 = ck::half_t;
using F32 = float;
using F8 = ck::f8_t;
using InputDataType = F16;
using ScaleDataType = F32;
using OutputDataType = F8;
static constexpr ck::index_t NumDim = 2;
constexpr ck::ReduceTensorOp ReduceOpId = ck::ReduceTensorOp::MAX;
constexpr bool PropagateNan = true;
constexpr bool OutputIndex = false;
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
struct ScalePassThrough
{
ScalePassThrough(const float alpha = 1.f) : alpha_(alpha) {}
__host__ __device__ constexpr void
operator()(OutputDataType& y0, OutputDataType& y1, const InputDataType& x0) const
{
y0 = ck::type_convert<OutputDataType>(ck::type_convert<ScaleDataType>(x0) * alpha_);
y1 = y0;
}
const ScaleDataType alpha_;
};
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryAbs = ck::tensor_operation::element_wise::UnaryAbs;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<InputDataType>, // InDataTypeTuple
ck::Tuple<OutputDataType, OutputDataType>, // OutDataTypeTuple
ScalePassThrough, // Elementwise
NumDim, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8, 1>>; // OutScalarPerVectorSeq
using DeviceReduceInstance =
ck::tensor_operation::device::DeviceReduceMultiBlock<OutputDataType,
OutputDataType,
OutputDataType,
NumDim,
NumDim,
ReduceOperation,
UnaryAbs,
PassThrough,
ck::InMemoryDataOperationEnum::Set,
PropagateNan,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
1024, // BlockSize
1, // MThreadClusterSize
1024, // KThreadClusterSize
1, // MThreadSliceSize
16, // KThreadSliceSize
1, // InSrcVectorDim
16, // InSrceVectorSize
1>; // OutDstVectorSize
void reference_scale_permute_amax(Tensor<InputDataType>& input,
Tensor<OutputDataType>& host_output_scaled_casted_transposed,
Tensor<OutputDataType>& host_output_scaled_casted,
Tensor<OutputDataType>& host_output_amax,
const float scale)
{
ScalePassThrough out_element_op(scale);
const ck::index_t M = input.GetLengths()[0];
const ck::index_t K = input.GetLengths()[1];
for(ck::index_t m = 0; m < M; m++)
{
for(ck::index_t k = 0; k < K; k++)
{
OutputDataType y0, y1;
out_element_op(y0, y1, input(m, k));
host_output_scaled_casted(m, k) = y0;
host_output_scaled_casted_transposed(m, k) = y1;
const OutputDataType y_fabs =
ck::type_convert<OutputDataType>(ck::math::abs(ck::type_convert<float>(y0)));
host_output_amax(0) = ck::math::max(y_fabs, host_output_amax(0));
}
}
}
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
const float scale = 2.f;
ck::index_t M = 1024;
ck::index_t K = 1024;
if(argc == 3)
{
M = std::stoi(argv[1]);
K = std::stoi(argv[2]);
}
std::array<ck::index_t, 2> dims = {M, K};
std::array<ck::index_t, 2> in_strides = {K, 1};
std::array<ck::index_t, 2> out_strides = {1, M};
Tensor<InputDataType> input(dims, in_strides);
Tensor<OutputDataType> output_scaled_casted_transposed(dims, out_strides);
Tensor<OutputDataType> output_scaled_casted(dims, in_strides);
Tensor<OutputDataType> output_amax({1});
input.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
DeviceMem input_dev_buf(sizeof(InputDataType) * input.mDesc.GetElementSpaceSize());
DeviceMem output_scaled_casted_transposed_dev_buf(
sizeof(OutputDataType) * output_scaled_casted_transposed.mDesc.GetElementSpaceSize());
DeviceMem output_scaled_casted_dev_buf(sizeof(OutputDataType) *
output_scaled_casted.mDesc.GetElementSpaceSize());
DeviceMem output_amax_dev_buf(sizeof(OutputDataType) * output_amax.mDesc.GetElementSpaceSize());
input_dev_buf.ToDevice(input.mData.data());
std::array<const void*, 1> inputs = {input_dev_buf.GetDeviceBuffer()};
std::array<void*, 2> outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(),
output_scaled_casted_dev_buf.GetDeviceBuffer()};
std::cout << "Input: " << input.mDesc << std::endl;
std::cout << "Scale: " << scale << std::endl;
std::cout << "Output scaled casted transposed: " << output_scaled_casted_transposed.mDesc
<< std::endl;
std::cout << "Output scaled casted: " << output_scaled_casted.mDesc << std::endl;
std::cout << "Output amax: " << output_amax.mDesc << std::endl;
auto launch_transpose_scale = [&]() {
auto transposeScale = DeviceElementwisePermuteInstance{};
auto argument = transposeScale.MakeArgumentPointer(dims,
{in_strides},
{out_strides, in_strides},
inputs,
outputs,
ScalePassThrough{scale});
if(!transposeScale.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
auto transposeScale_invoker_ptr = transposeScale.MakeInvokerPointer();
return transposeScale_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
};
auto launch_reduce = [&]() {
auto reduce = DeviceReduceInstance{};
auto reduce_argument_ptr =
reduce.MakeArgumentPointer(dims,
in_strides,
{1}, // Output Lengths
{1}, // Output Strides
{0, 1}, // Reduce Dims
static_cast<double>(1.f),
static_cast<double>(0.f),
output_scaled_casted_dev_buf.GetDeviceBuffer(),
nullptr,
output_amax_dev_buf.GetDeviceBuffer(),
nullptr,
UnaryAbs{},
PassThrough{});
if(!reduce.IsSupportedArgument(reduce_argument_ptr.get()))
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
auto invoker_ptr = reduce.MakeInvokerPointer();
return invoker_ptr->Run(reduce_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
};
float ave_time = launch_transpose_scale();
ave_time += launch_reduce();
std::cout << "Perf: " << ave_time << " ms" << std::endl;
bool pass = true;
if(do_verification)
{
Tensor<OutputDataType> host_output_scaled_casted_transposed(dims, out_strides);
Tensor<OutputDataType> host_output_scaled_casted(dims, in_strides);
Tensor<OutputDataType> host_output_amax({1});
reference_scale_permute_amax(input,
host_output_scaled_casted_transposed,
host_output_scaled_casted,
host_output_amax,
scale);
output_scaled_casted_transposed_dev_buf.FromDevice(
output_scaled_casted_transposed.mData.data());
output_scaled_casted_dev_buf.FromDevice(output_scaled_casted.mData.data());
output_amax_dev_buf.FromDevice(output_amax.mData.data());
pass &= ck::utils::check_err(output_scaled_casted_transposed.mData,
host_output_scaled_casted_transposed.mData,
"Error: Incorrect results scaled transposed",
1e-3,
1e-3);
pass &= ck::utils::check_err(output_scaled_casted.mData,
host_output_scaled_casted.mData,
"Error: Incorrect results scaled",
1e-3,
1e-3);
pass &= ck::utils::check_err(
output_amax.mData, host_output_amax.mData, "Error: Incorrect results amax", 1e-3, 1e-3);
}
return pass ? 0 : 1;
}
...@@ -43,16 +43,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -43,16 +43,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true; constexpr bool kPadA = true;
constexpr bool kPadB = true; constexpr bool kPadB = true;
constexpr bool kTilePermute = false;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>; // The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadA,
kPadB,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b, args.p_b,
...@@ -255,15 +276,13 @@ int main(int argc, char* argv[]) ...@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem<ADataType, using CodegenGemmTraits = ck_tile::
BDataType, TileGemmTraits<kPadA, kPadB, kPadC, matrix_a_layout, matrix_b_layout, matrix_c_layout>;
AccDataType,
CodegenGemmShape,
kPadA,
kPadB,
kPadC>;
using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>; using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
invoke_gemm<ck_tile::half_t, invoke_gemm<ck_tile::half_t,
matrix_a_layout, matrix_a_layout,
...@@ -341,7 +360,13 @@ int main(int argc, char* argv[]) ...@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions); ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>( ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout>(
a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c); a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
c_buf.FromDevice(c_host_gpu_ref.data()); c_buf.FromDevice(c_host_gpu_ref.data());
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "device_base.hpp" #include "device_base.hpp"
...@@ -37,7 +37,7 @@ struct DeviceCGemm : public BaseOperator ...@@ -37,7 +37,7 @@ struct DeviceCGemm : public BaseOperator
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC) = 0; index_t StrideC) const = 0;
}; };
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
......
...@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[maybe_unused]] index_t K, [[maybe_unused]] index_t K,
[[maybe_unused]] index_t StrideA, [[maybe_unused]] index_t StrideA,
[[maybe_unused]] index_t StrideB, [[maybe_unused]] index_t StrideB,
index_t StrideC) override index_t StrideC) const override
{ {
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC); return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
} }
std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
{
const auto* parg = dynamic_cast<const Argument*>(base_arg);
if(!parg)
{
std::ostringstream err;
err << "Provided argument pointer is not of an Argument class!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return GetWorkspaceSize(
parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
}
}; };
} // namespace device } // namespace device
......
...@@ -419,6 +419,12 @@ struct UnaryAbs ...@@ -419,6 +419,12 @@ struct UnaryAbs
y = ck::math::abs(x); y = ck::math::abs(x);
}; };
template <>
__host__ __device__ void operator()(f8_t& y, const f8_t& x) const
{
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
};
}; };
struct UnarySqrt struct UnarySqrt
......
This diff is collapsed.
...@@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x) ...@@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x) static inline __host__ bool isnan(int4_t x)
{ {
...@@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x) ...@@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __device__ half_t sqrt(half_t x) static inline __device__ half_t sqrt(half_t x)
{ {
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
......
...@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, ...@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const BElementOp& b_element_op = {}, const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {}) const ACCElementOp& acc_element_op = {})
{ {
const int N = b_n_k.mDesc.get_lengths()[0]; const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_n_k.mDesc.get_lengths()[0]
: b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>) const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1] ? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0]; : a_m_k.mDesc.get_lengths()[0];
...@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, ...@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>) ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k)) ? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m)); : a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k)); BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_element_op(b_n_k(n, k))
: b_element_op(b_n_k(k, n));
v_acc += ck_tile::type_convert<AccDataType>(v_a) * v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b); ck_tile::type_convert<AccDataType>(v_b);
} }
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc)); CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? c_m_n(m, n)
: c_m_n(n, m);
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
} }
}; };
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType> template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A, __global__ void naive_gemm_kernel(ADataType* A,
BDataType* B, BDataType* B,
CDataType* C, CDataType* C,
...@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A, ...@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if(row < M && col < N) if(row < M && col < N)
{ {
AccDataType acc = 0.0; AccDataType acc = 0.0;
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
acc += static_cast<AccDataType>(A[row * strideA + k]) * // Adjust indexing based on matrix layout
static_cast<AccDataType>(B[col * strideB + k]); int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
} }
C[row * strideC + col] = acc; // Store as AccDataType int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = acc;
} }
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType> template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(DeviceMem& a_device, void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device, DeviceMem& b_device,
DeviceMem& c_device, DeviceMem& c_device,
...@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device, ...@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int numThreadsPerBlock = 256; // Common choice for threads per block int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType> naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy( errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
......
...@@ -3,5 +3,6 @@ ...@@ -3,5 +3,6 @@
#pragma once #pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace ck_tile {
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool kTilePermute_,
index_t kRank_,
index_t kPerm0,
index_t kPerm1,
index_t TileSize0,
index_t TileSize1,
index_t kPerm2 = 0,
index_t kPerm3 = 0,
index_t kPerm4 = 0,
index_t TileSize2 = 0,
index_t TileSize3 = 0,
index_t TileSize4 = 0>
struct CShuffleEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTilePermute = kTilePermute_;
static constexpr index_t kRank = kRank_;
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
const index_t* kPerm = Problem::kPerm;
static constexpr bool kTilePermute = Problem::kTilePermute;
static constexpr index_t kRank = Problem::kRank;
const index_t* tile_sizes = Problem::tile_sizes;
// No additional shared memory needed
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename OAccTile>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
{
using DataType = typename OAccTile::DataType;
// Get thread buffer
auto& thread_buf = o_acc_tile.get_thread_buffer();
// Create a temporary buffer to hold the permuted data
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
// Get the lengths of each dimension
auto thread_tensor_lengths = o_acc_tile.get_lengths();
// Total number of elements
index_t total_elements = OAccTile::kThreadElementSpaceSize;
// Iterate over all elements
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
{
// Convert linear index to multi-dimensional indices
array<index_t, kRank> indices;
index_t remaining = linear_idx;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
remaining /= thread_tensor_lengths.get(number<rev_i>{});
});
// Apply the permutation
array<index_t, kRank> permuted_indices;
static_for<0, kRank, 1>{}(
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
// Compute offsets
index_t dst_offset = 0;
index_t stride = 1;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
dst_offset += permuted_indices[rev_i] * stride;
stride *= thread_tensor_lengths.get(number<rev_i>{});
});
// Move the data
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
}
// Copy the permuted data back to the original thread buffer
for(index_t i = 0; i < total_elements; ++i)
{
thread_buf.set_as(i, permuted_thread_buf.get(i));
}
}
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
{
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
for(index_t i = 0; i < kRank; ++i)
{
permuted_tile_coords[i] = tile_coords[kPerm[i]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename ODramWindowTmp::BottomTensorIndex step = {};
for(index_t i = 0; i < kRank; ++i)
{
step[i] = permuted_window_origin[i] - current_window_origin[i];
}
// Move the window
move_tile_window(o_dram_window_tmp, step);
// Permute the data within the tile if necessary
if constexpr(kTilePermute)
{
permute_tile_data(o_acc_tile);
}
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
};
} // namespace ck_tile
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
...@@ -25,10 +25,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -25,10 +25,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kK0>,
...@@ -52,16 +53,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -52,16 +53,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::GemmDataType, BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType, typename Problem::OGradDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim, Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>, Problem::BlockFmhaShape::kK1>,
...@@ -84,16 +86,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -84,16 +86,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::OGradDataType, BlockGemmProblem<typename Problem::OGradDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>, Problem::BlockFmhaShape::kK2>,
...@@ -117,16 +120,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -117,16 +120,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm2BlockWarps, typename Problem::BlockFmhaShape::Gemm2BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::GemmDataType, BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::QDataType, typename Problem::QDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>, Problem::BlockFmhaShape::kK3>,
...@@ -149,16 +153,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -149,16 +153,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm3BlockWarps, typename Problem::BlockFmhaShape::Gemm3BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::GemmDataType, BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>, Problem::BlockFmhaShape::kK4>,
...@@ -181,7 +186,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -181,7 +186,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm4BlockWarps, typename Problem::BlockFmhaShape::Gemm4BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
// these are for global load // these are for global load
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
...@@ -75,10 +76,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -75,10 +76,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kK0>,
...@@ -116,7 +118,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -116,7 +118,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>; decltype(warp_gemm)>;
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
} }
}; };
...@@ -199,10 +201,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -199,10 +201,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kK0>,
...@@ -240,7 +243,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -240,7 +243,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>; decltype(warp_gemm)>;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmASmemBSmemCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
}; };
...@@ -954,10 +957,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -954,10 +957,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::PDataType, BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::OaccDataType, typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1, Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>, Problem::BlockFmhaShape::kK1>,
...@@ -996,7 +1000,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -996,7 +1000,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
typename Problem::OaccDataType, typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
} }
}; };
......
...@@ -23,12 +23,13 @@ ...@@ -23,12 +23,13 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment