Commit 72c9f129 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 241c261f ded0d83d
......@@ -103,7 +103,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.6.2
rocm-docs-core==1.7.2
# via -r requirements.in
six==1.16.0
# via pybtex
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
......@@ -7,7 +7,7 @@
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
using CDataType = ck::half_t;
using CDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -34,11 +34,11 @@ inline __host__ __device__ constexpr double get_rtol()
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
return 2e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
return 2e-1;
}
else
{
......@@ -75,11 +75,11 @@ inline __host__ __device__ constexpr double get_atol()
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
return 2e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
return 2e-1;
}
else
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <initializer_list>
......@@ -255,34 +255,61 @@ int main(int argc, char* argv[])
else
{
// for testing half_t
pass =
pass && reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass =
pass && reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing float
pass =
pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing double
pass =
pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing bhalf_t
pass = pass &&
reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass &&
reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing int8_t
pass =
pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass =
pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// for testing int4_t using AVG operation
pass =
pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing int4_t using MAX operation
pass =
pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
#endif
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -38,7 +38,8 @@ struct ReduceShape
static constexpr ck::index_t NumReduceDim_ = NumReduceDim;
};
using reduce_shape_instances = std::tuple<ReduceShape<3, 1>,
using reduce_shape_instances = std::tuple<ReduceShape<12, 3>,
ReduceShape<3, 1>,
ReduceShape<3, 2>,
ReduceShape<4, 1>,
ReduceShape<4, 2>,
......
......@@ -23,12 +23,8 @@
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
#ifdef CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......
......@@ -3,6 +3,7 @@ add_subdirectory(convinvscale)
add_subdirectory(convscale)
add_subdirectory(convscale_relu)
add_subdirectory(convscale_add)
add_subdirectory(convscale_reduce)
add_subdirectory(multi_AB)
add_subdirectory(unary)
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_activ_xdl_convscale_reduce)
add_example_executable(example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_relu_amax_fp8)
add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8)
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/type.hpp"
namespace ew = ck::tensor_operation::element_wise;
using PassThrough = ew::PassThrough;
using ConvScaleRelu = ew::UnaryCombinedOp<ew::Scale, ew::Scale, ew::Relu>;
using ConvScale = ew::UnaryCombinedOp<ew::Scale, ew::Scale, PassThrough>;
using UnaryScaleConvert = ew::Scale;
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename ConvOutDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename ConvElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op)
{
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<ConvOutDataType> host_conv(out_g_n_k_wos_desc);
Tensor<ConvOutDataType> device_conv(out_g_n_k_wos_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
case 11: // used for debugging
in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem conv_device_buf(conv_param.GetOutputByte<ConvOutDataType>());
DeviceMem out_device_buf(conv_param.GetOutputByte<OutDataType>());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// random scale values
float scale_in = float(std::rand()) / float(RAND_MAX);
float scale_wei = float(std::rand()) / float(RAND_MAX);
float scale_out = float(std::rand()) / float(RAND_MAX);
std::cout << std::endl;
std::cout << "scale_in: " << scale_in << std::endl;
std::cout << "scale_wei: " << scale_wei << std::endl;
std::cout << "scale_out: " << scale_out << std::endl;
// convolution elementwise operation
auto conv_element_op = ConvElementOp{ew::Scale{scale_in}, ew::Scale{scale_wei}, {}};
auto scale_convert = UnaryScaleConvert{scale_out}; // elementwise scale and type cast
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto conv_invoker = conv.MakeInvoker();
auto conv_argument =
conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
conv_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
conv_element_op);
if(!conv.IsSupportedArgument(conv_argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
}
std::string kernels = conv.GetTypeString();
float avg_time = conv_invoker.Run(conv_argument, StreamConfig{nullptr, time_kernel});
using DeviceElementwiseScale = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ConvOutDataType>, // InDataTypeTuple
ck::Tuple<OutDataType>, // OutDataTypeTuple
UnaryScaleConvert, // UnaryScaleConvert
NDimSpatial + 3, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
auto device_ew_scale = DeviceElementwiseScale{};
auto scale_invoker = device_ew_scale.MakeInvoker();
auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths,
{e_g_n_k_wos_strides},
{e_g_n_k_wos_strides},
{conv_device_buf.GetDeviceBuffer()},
{out_device_buf.GetDeviceBuffer()},
scale_convert);
if(!device_ew_scale.IsSupportedArgument(scale_argument))
{
throw std::runtime_error(
"wrong! DeviceElementwiseScale with the specified compilation parameters does "
"not support this problem");
}
kernels += std::string("\n\t\t ") + device_ew_scale.GetTypeString();
avg_time += scale_invoker.Run(scale_argument, StreamConfig{nullptr, time_kernel});
constexpr auto ReduceOpId = ck::ReduceTensorOp::AMAX;
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename ck::reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename ck::reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
using DeviceReduceInstance =
ck::tensor_operation::device::DeviceReduceMultiBlock<ConvOutDataType,
ConvOutDataType,
ConvOutDataType,
NDimSpatial + 3,
NDimSpatial + 3,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
ck::InMemoryDataOperationEnum::Set,
true, // PropagateNan
false, // OutputIndex
false, // HaveIndexInputIfOutputIndex
256, // BlockSize
4, // MThreadClusterSize
64, // KThreadClusterSize
1, // MThreadSliceSize
1, // KThreadSliceSize
1, // InSrcVectorDim
1, // InSrceVectorSize
1>; // OutDstVectorSize
std::vector<size_t> outLengths = {1};
Tensor<ConvOutDataType> amax_host(outLengths);
Tensor<ConvOutDataType> amax_from_device(outLengths);
auto amax_host_strides = amax_host.mDesc.GetStrides();
std::array<int, NDimSpatial + 3> reduce_dims;
std::iota(reduce_dims.begin(), reduce_dims.end(), 0); // 0,..., NDimSpatial+3-1
std::array<ck::index_t, 1> reduce_out_lengths{1};
std::array<ck::index_t, 1> reduce_out_strides{static_cast<ck::index_t>(amax_host_strides[0])};
DeviceMem amax_device(sizeof(ConvOutDataType) * amax_host.mDesc.GetElementSpaceSize());
DeviceMem index_device;
InElementwiseOperation in_elementwise_op;
AccElementwiseOperation acc_elementwise_op;
std::tie(in_elementwise_op, acc_elementwise_op) =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
static_cast<int32_t>(host_conv.mDesc.GetElementSize()));
// Hack convolution output strides for reduction as kernel expects stride 1 for the last
// dimension. It only works because the reduction is done on the whole tensor and result is
// independent of the order of elements.
std::array<ck::index_t, NDimSpatial + 3> reduction_strides{};
copy(HostTensorDescriptor(e_g_n_k_wos_lengths).GetStrides(), reduction_strides);
auto device_reduce = DeviceReduceInstance{};
auto reduce_invoker = device_reduce.MakeInvokerPointer();
auto reduce_argument = device_reduce.MakeArgumentPointer(e_g_n_k_wos_lengths,
reduction_strides,
reduce_out_lengths,
reduce_out_strides,
reduce_dims,
1.0,
0.0,
conv_device_buf.GetDeviceBuffer(),
nullptr,
amax_device.GetDeviceBuffer(),
nullptr,
in_elementwise_op,
acc_elementwise_op);
if(!device_reduce.IsSupportedArgument(reduce_argument.get()))
{
throw std::runtime_error(
"wrong! DeviceReduceInstance with the specified compilation parameters does "
"not support this runtime parameters!");
};
kernels += std::string("\n\t\t ") + device_reduce.GetTypeString();
float reduce_time =
reduce_invoker->Run(reduce_argument.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
std::cout << "\nReduce time: " << reduce_time << " ms" << std::endl;
avg_time += reduce_time;
std::size_t flop = conv_param.GetFlops(); // convolution FLOPs
auto conv_out_elems = host_conv.GetElementSize(); // number of elements in conv result tensor
// 3 element-wise scale multipliers + 1 AMAX
std::size_t elementwise_ops = 3 + 1;
if constexpr(ck::is_same_v<ConvElementOp, ConvScaleRelu>)
{
elementwise_ops += 1; // +1 element-wise relu
}
flop += elementwise_ops * conv_out_elems;
// convolution + elementwise scaling (in + wei + output byte count)
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, ConvOutDataType>();
num_btype += sizeof(float) + sizeof(float); // + 2 scales
// elementwise scaling + F8 conversion
num_btype += conv_param.GetOutputByte<ConvOutDataType>() + sizeof(float) +
conv_param.GetOutputByte<OutDataType>();
// AMAX
num_btype += conv_param.GetOutputByte<ConvOutDataType>() + sizeof(float);
if(time_kernel)
{
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << std::endl;
}
std::cout << "\nKernels: " << kernels << std::endl;
if(do_verification)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
ConvOutDataType,
InElementOp,
WeiElementOp,
ConvElementOp>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
host_conv,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
conv_element_op);
ref_invoker.Run(ref_argument);
conv_device_buf.FromDevice(device_conv.mData.data());
out_device_buf.FromDevice(out_device.mData.data());
out_host.ForEach([&](auto&, auto idx) { scale_convert(out_host(idx), host_conv(idx)); });
std::cout << "\nComparing output to reference: " << std::endl;
auto tight_tol_check = ck::utils::check_err(out_device, out_host, "Error: ");
if(!tight_tol_check)
{
std::cout << "\n\tRecompare applying tolerances...\n";
std::cout << "\t\trtol = " << get_rtol<OutDataType>() << std::endl;
std::cout << "\t\tatol = " << get_atol<OutDataType>() << std::endl;
auto loose_tol_check = ck::utils::check_err(out_device,
out_host,
"Error: incorrect convolution results!",
get_rtol<OutDataType>(),
get_atol<OutDataType>());
if(!loose_tol_check)
{
return false;
}
}
std::cout << "Success!" << std::endl;
/// Verify AMAX
using RefReduceInstance =
ck::tensor_operation::host::ReferenceReduce<ConvOutDataType,
ConvOutDataType,
ConvOutDataType,
NDimSpatial + 3,
NDimSpatial + 3,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
true,
false>;
auto ref_reduce = RefReduceInstance{};
auto ref_reduce_invoker = ref_reduce.MakeInvokerPointer();
auto ref_reduce_argument = ref_reduce.MakeArgumentPointer(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
reduce_out_lengths,
reduce_out_strides,
reduce_dims,
1.0,
0.0,
host_conv.mData.data(),
nullptr,
amax_host.mData.data(),
nullptr,
in_elementwise_op,
acc_elementwise_op);
if(!ref_reduce.IsSupportedArgument(ref_reduce_argument.get()))
{
throw std::runtime_error(
"wrong! RefReduceInstance with the specified compilation parameters does "
"not support this runtime parameters!");
};
ref_reduce_invoker->Run(ref_reduce_argument.get());
amax_device.FromDevice(amax_from_device.mData.data());
std::cout << "\namax: " << amax_from_device.mData[0] << std::endl;
std::cout << "amax_ref: " << amax_host.mData[0] << std::endl;
return ck::utils::check_err(amax_from_device, amax_host, "Error: incorrect AMAX results!");
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using ConvOutDataType = float; // data type of convolution result
using OutDataType = ck::f8_t; // data type of final result
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
ConvOutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
AComputeDataType,
BComputeDataType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using ConvOutDataType = float; // data type of convolution result
using OutDataType = ck::f8_t; // data type of final result
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScaleRelu;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
ConvOutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
AComputeDataType,
BComputeDataType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_convnd_fwd_example(int argc, char* argv[])
{
print_helper_msg();
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::utils::conv::ConvParam conv_param{
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
}
// instantiate in and wei element ops, will
// instantiate out_element_op below for every iteration
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) {
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using OutLayout = decltype(out_layout);
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_grouped_conv_fwd<
ndim_spatial_value,
InDataType,
WeiDataType,
ConvOutDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>>(
do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op);
};
namespace ctc = ck::tensor_layout::convolution;
if(conv_param.num_dim_spatial_ == 1)
{
return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{});
}
else if(conv_param.num_dim_spatial_ == 2)
{
return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{});
}
else if(conv_param.num_dim_spatial_ == 3)
{
return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
}
return true;
}
......@@ -208,6 +208,7 @@ int main(int argc, char* argv[])
StrideB,
std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
......
......@@ -69,7 +69,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off
......@@ -99,6 +99,8 @@ int main(int argc, char* argv[])
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
ck::index_t KBatch = 1;
if(argc == 1)
{
// use default case
......@@ -109,7 +111,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
else if(argc == 12)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -123,13 +125,16 @@ int main(int argc, char* argv[])
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
KBatch = std::stoi(argv[11]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
printf(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
exit(0);
}
......@@ -212,6 +217,7 @@ int main(int argc, char* argv[])
StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0},
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
......@@ -236,10 +242,12 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false});
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
......
......@@ -72,6 +72,20 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
message("removing fp8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8")
message("removing bf8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl")
......
# generate a list of kernels, but not actually emit files at config stage
# validate user-specified fmha_fwd API list
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
endif()
foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
endif()
endforeach()
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
endif()
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
endif()
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
endif()
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
# as current cmake list, otherwise will not figure out the dependency properly
......@@ -17,13 +45,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
......@@ -55,10 +83,23 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
# conditionally enable call to the fwd_splitkv API in fmha_fwd example
if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
endif()
# conditionally enable call to the fwd_appendkv API in fmha_fwd example
if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
endif()
# Allow comparing floating points directly in order to check sentinel values
......
......@@ -66,6 +66,34 @@ BIAS_CHECK_MAP = {
"alibi" : "bias_enum::alibi"
}
DROPOUT_MAP = {
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
}
DROPOUT_CHECK_MAP = {
"no" : "t.has_dropout == false",
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
}
ROPE_MAP = {
"no" : "ck_tile::RotaryEmbeddingEnum::NONE",
"inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
"half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
}
ROPE_CHECK_MAP = {
"no" : "rope_enum::none",
"inter" : "rope_enum::interleaved",
"half" : "rope_enum::half_rotated"
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
......@@ -89,4 +117,4 @@ PIPELINE_ENUM_MAP = {
BOOL_MAP = {
"t" : "true",
"f" : "false"
}
\ No newline at end of file
}
......@@ -14,15 +14,13 @@ from codegen.cpp_symbol_map import *
BWD_DQDKDV_PIPELINE_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP",
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR",
}
BWD_DQDKDV_PIPELINE_ENUM_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP",
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR",
}
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
......@@ -34,39 +32,42 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_tile_{F_idx} = ck_tile::
sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile_{F_idx}>;
fmha_block_warps0_{F_idx},
fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile0_{F_idx}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
{F_dbias},
false,
{F_dropout},
false,
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
{F_dbias},
false,
false,
false,
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
......@@ -86,55 +87,72 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
fmha_bwd_shape_{F_idx},
{F_mode},
{F_deterministic},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
fmha_bwd_trait_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_dk_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
false, false>>;
using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
{F_skpad},
{F_dpad}>>;
using fmha_bwd_dv_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
false, false>>;
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
{F_skpad},
{F_dvpad}>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_pipeline_enum},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
{F_bias},
{F_dbias},
{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_deterministic}>;
#include <iostream>
template<>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template<>
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
......@@ -146,14 +164,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_API="""
#include <iostream>
template<typename dot_do_o_trait_, typename dq_dk_dv_trait_>
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
......@@ -173,38 +192,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
"""
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a);
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}}
"""
@dataclass
class FmhaBwdDQDKDVApiTrait:
pipeline : str
pipeline : str
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
@property
def name(self) -> str:
return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
deterministic : str
def scheck(self, spad1 : str) -> str:
if self.mode == 'group':
......@@ -212,9 +229,9 @@ class FmhaBwdDQDKDVApiTrait:
elif self.spad == 't' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} != 0'
elif self.spad == 'f' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
else: # self.skpad == 'f' and skpad1 == 'f'
return f'a.seqlen_q % 256 == 0' # BlockSize
return f'a.seqlen_q % 64 == 0'
@property
def skcheck(self) -> str:
......@@ -256,16 +273,19 @@ class FmhaBwdApiPool:
per_hdim_case=str()
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
traits=self.dq_dk_dv_pool[dtype][hdim]
hdim_int = int(hdim)
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
for spad1 in ["t", "f"]:
if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
continue
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
......@@ -295,81 +315,89 @@ class FmhaBwdDQDKDVTileSize:
F_bhdv : int # v head_dim
F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2
F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2
F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2
F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2
F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3
F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3
F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3
F_rm2 : int # number of warps along k seqlen (block warps) in gemm4
F_rn2 : int # number of warps along q seqlen (block warps) in gemm4
F_rk2 : int # number of warps along gemm-k (not used) in gemm4
F_wm : int # warp size along m (warp size)
F_wn : int # warp size along n
F_wk : int # warp size along k
F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3
F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3
F_rm2 : int # number of warps along q seqlen (block warps) in gemm4
F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4
F_rk2 : int # number of warps along k seqlen (not used) in gemm4
F_wm0 : int # warp size along m in gemm0/gemm2/gemm4
F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
F_wk0 : int # warp size along k in gemm0/gemm2/gemm4
F_wm1 : int # warp size along m in gemm1/gemm3
F_wn1 : int # warp size along n in gemm1/gemm3
F_wk1 : int # warp size along k in gemm1/gemm3
F_occupancy : int # occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}"
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
@dataclass
class FmhaBwdDQDKDVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_bias : str #
F_dbias : str #
F_dropout : str #
F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP
F_pipeline : str
mask_impl : str
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_bias : str #
F_dbias : str #
F_dropout : str #
F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP
F_deterministic : str #
F_pipeline : str #
mask_impl : str #
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bk1 = self.F_tile.F_bk1,
F_bk2 = self.F_tile.F_bk2,
F_bk3 = self.F_tile.F_bk3,
F_bk4 = self.F_tile.F_bk4,
F_bhdq = self.F_tile.F_bhdq,
F_bhdv = self.F_tile.F_bhdv,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_rm2 = self.F_tile.F_rm2,
F_rn2 = self.F_tile.F_rn2,
F_rk2 = self.F_tile.F_rk2,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_spad = BOOL_MAP[self.F_spad],
F_skpad = BOOL_MAP[self.F_skpad],
F_dpad = BOOL_MAP[self.F_dpad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_bias = BIAS_MAP[self.F_bias],
F_dbias = BOOL_MAP[self.F_dbias],
F_dropout = BOOL_MAP[self.F_dropout],
F_occupancy = self.F_tile.F_occupancy,
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bk1 = self.F_tile.F_bk1,
F_bk2 = self.F_tile.F_bk2,
F_bk3 = self.F_tile.F_bk3,
F_bk4 = self.F_tile.F_bk4,
F_bhdq = self.F_tile.F_bhdq,
F_bhdv = self.F_tile.F_bhdv,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_rm2 = self.F_tile.F_rm2,
F_rn2 = self.F_tile.F_rn2,
F_rk2 = self.F_tile.F_rk2,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_spad = BOOL_MAP[self.F_spad],
F_skpad = BOOL_MAP[self.F_skpad],
F_dpad = BOOL_MAP[self.F_dpad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_bias = BIAS_MAP[self.F_bias],
F_dbias = BOOL_MAP[self.F_dbias],
F_dropout = DROPOUT_MAP[self.F_dropout],
F_occupancy = self.F_tile.F_occupancy,
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_deterministic = BOOL_MAP[self.F_deterministic],
F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
@property
def name(self) -> str:
......@@ -382,7 +410,7 @@ class FmhaBwdDQDKDVKernel:
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_dbias == 't' : n += '_dbias'
......@@ -390,7 +418,8 @@ class FmhaBwdDQDKDVKernel:
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_dropout == 't' : n += '_dropout'
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
if self.F_deterministic == 't' : n += '_deterministic'
return n
@property
......@@ -413,19 +442,23 @@ class FmhaBwdDQDKDVKernel:
spad=self.F_spad,
skpad=self.F_skpad,
dpad=self.F_dpad,
dvpad=self.F_dvpad)
dvpad=self.F_dvpad,
deterministic=self.F_deterministic
)
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1),
"qs_ks_vr_dos"],
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
"qs_ks_vr_dos"],
'128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
"ks_vr"]
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"]
}
else:
return None
......@@ -440,7 +473,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
tile = d[hdim_str][0]
ppl = d[hdim_str][1]
hdim = int(hdim_str)
......@@ -448,16 +481,29 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ("wg32" in dropout):
continue
if (dpad == "t" or dvpad == "t"):
ppl = d[hdim_str][2]
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
F_pipeline=ppl, mask_impl=mask_impl)
F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
if not cond:
continue
if receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait())
......@@ -468,53 +514,54 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
{F_dvpad},
{F_occupancy}>;
using fmha_bwd_dot_do_o_trait_{F_idx} =
ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
/* BlockSize = */ 256,
/* BlockSize = */ 64,
{F_hdim},
{F_mode},
fmha_bwd_dot_do_o_trait_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} =
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
fmha_bwd_dot_do_o_{F_idx}>;
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
#include <iostream>
template<>
template <>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template<>
template <>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
......@@ -584,12 +631,150 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
return gen
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_convert_dq_trait_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
/* BlockSize = */ 256,
{F_bm0},
{F_bn0},
{F_hdim},
{F_mode},
{F_deterministic},
fmha_bwd_convert_dq_trait_{F_idx}>;
using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
using fmha_bwd_convert_dq_kernel_{F_idx} =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic}>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
return k_::GetName();
}}
"""
@dataclass
class FmhaBwdConvertQGradKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_spad : str # true/false
F_dpad : str #
F_mode : str # value from MODE_MAP
F_occupancy : int #
F_deterministic : str #
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0,
F_spad = BOOL_MAP[self.F_spad],
F_dpad = BOOL_MAP[self.F_dpad],
F_mode = MODE_MAP[self.F_mode],
F_occupancy = self.F_occupancy,
F_deterministic = BOOL_MAP[self.F_deterministic])
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_dpad == 't' : n += 'd'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
if self.F_deterministic == 't' : n += f'_deterministic'
return n
@property
def filename(self) -> str:
return self.name + ".cpp"
def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
gen = list()
for dtype in DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
hdim = int(hdim_str)
tile = d[hdim_str][0]
if (mode == "group" and spad == "f"):
continue
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
gen.append(k)
return gen
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
......@@ -597,6 +782,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
kernels = get_bwd_convert_dq_blobs()
for kernel in kernels:
write_single_bwd_convert_dq_kernel(kernel, output_dir)
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
......@@ -605,6 +793,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
with file_path.open('a') as f:
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
kernels = get_bwd_convert_dq_blobs()
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.ops.fmha_fwd import (
FmhaFwdApiTrait,
DTYPE_BITS,
FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
)
FMHA_FWD_APPENDKV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_occupancy}>;
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
{F_bs},
{F_bsk},
{F_bd},
{F_bdv},
{F_vlayout},
{F_rope},
{F_pagedkv},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
fmha_pipeline_problem_{F_idx}>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
#include <iostream>
template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API="""
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
"""
@dataclass
class FmhaFwdAppendKVApiTrait:
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
bs : int # tile size along q seqlen
bsk : int # tile size along k seqlen
bd : int # tile size along qk gemm unroll
bdv : int # tile size along kv gemm unroll
vlayout : str
spad : str
skpad : str
dpad : str
dvpad : str
rope : str # key from ROPE_MAP
pagedkv : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'
@property
def scheck(self) -> str:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
else : return f'a.seqlen_q % {self.bs} == 0'
@property
def skcheck(self) -> str:
# we do not check all the values in a.seqlen_k_ptr
return 'true'
@property
def dcheck(self) -> str:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bd} == 0'
@property
def dvcheck(self) -> str:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bdv} == 0'
@dataclass
class FmhaFwdAppendKVPipeline:
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_rope : str # key from ROPE_MAP
F_pagedkv : str # t/f
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_rope != 'no': n += f'_{self.F_rope}'
if self.F_pagedkv == 't': n += '_pagedkv'
return n
class FmhaFwdAppendKVApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdAppendKVTileSize:
F_bs : int # tile size along q seqlen
F_bsk : int # tile size along k seqlen
F_bd : int # tile size along qk gemm unroll
F_bdv : int # tile size along kv gemm unroll
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdAppendKVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaFwdAppendKVTileSize
F_pipeline : FmhaFwdAppendKVPipeline
mask_impl : str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd,
F_bdv = self.F_tile.F_bdv,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy = self.F_tile.F_occupancy)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdAppendKVApiTrait:
return FmhaFwdAppendKVApiTrait(
hdim=str(self.F_hdim),
dtype=self.F_dtype,
bs=self.F_tile.F_bs,
bsk=self.F_tile.F_bsk,
bd=self.F_tile.F_bd,
bdv=self.F_tile.F_bdv,
vlayout=self.F_pipeline.F_vlayout,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
rope=self.F_pipeline.F_rope,
pagedkv=self.F_pipeline.F_pagedkv)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
}
else:
return None
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
# applying rotary embedding, so I just use 't' in inter/half pipelines
for vlayout in ['row', 'col']:
for pagedkv in ["t", "f"]:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
elif dtype in ['fp8', 'bf8']:
# rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str in d.keys():
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
k = FmhaFwdAppendKVKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_fwd_appendkv_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
......@@ -21,6 +21,14 @@ from codegen.ops.fmha_fwd import (
)
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
FMHA_FWD_SPLITKV_PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
......@@ -51,8 +59,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_bias},
false,
{F_lse},
{F_dropout},
{F_squant},
{F_pagedkv},
kHasUnevenSplits,
{F_occupancy}>;
......@@ -63,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
......@@ -86,7 +93,7 @@ using fmha_kernel =
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
......@@ -97,16 +104,21 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
}};
}}
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;
#include <iostream>
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if constexpr({F_mode} == false) {{ // batch mode
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
// we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<true>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<false>::run(s, a);
}} else {{
kernel_runner<true>::run(s, a);
......@@ -160,7 +172,7 @@ using fmha_kernel =
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
......@@ -177,7 +189,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
#include <iostream>
template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if (a.num_splits <= 16) {{
kernel_runner<4>::run(s, a);
......@@ -203,7 +215,7 @@ FMHA_FWD_SPLITKV_API="""
#include <iostream>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if(s.log_level_ > 0)
std::cout
......@@ -217,22 +229,96 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
);
}}
float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
"""
@dataclass
class FmhaFwdSplitKVApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0blen : int
vlayout : str
mask : str
bias : str #
lse : str #
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
pagedkv : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
f'{self.dvpad}-{self.pagedkv}'
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bk0blen} == 0'
else: assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bk0blen} == 0'
else: assert False
@dataclass
class FmhaFwdSplitKVPipeline:
tag : str
......@@ -244,8 +330,8 @@ class FmhaFwdSplitKVPipeline:
F_dvpad : str #
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_pagedkv : str # t/f
F_mask : str # value from MASK_MAP
@property
......@@ -267,8 +353,8 @@ class FmhaFwdSplitKVPipeline:
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse'
if self.F_dropout == 't' : n += '_dropout'
if self.F_squant == 't' : n += '_squant'
if self.F_pagedkv == 't' : n += '_pagedkv'
return n
@dataclass
......@@ -300,7 +386,7 @@ class FmhaFwdSplitKVApiPool:
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
......@@ -322,8 +408,8 @@ class FmhaFwdSplitKVApiPool:
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
......@@ -383,8 +469,8 @@ class FmhaFwdSplitKVKernel:
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
......@@ -401,8 +487,8 @@ class FmhaFwdSplitKVKernel:
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
def api_trait(self) -> FmhaFwdSplitKVApiTrait:
return FmhaFwdSplitKVApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
......@@ -417,8 +503,8 @@ class FmhaFwdSplitKVKernel:
mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
pagedkv=self.F_pipeline.F_pagedkv,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
......@@ -460,29 +546,6 @@ class FmhaFwdSplitKVCombineKernel:
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0blen=self.F_tile.F_bk0blen,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
......@@ -533,27 +596,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
# splitkv kernel donot support dropout
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
if hdim == 256:
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
# TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]:
# if True:
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
else:
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
if receipt == 1:
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
# no need lse/paged-kv kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask))
else:
assert False
return pipelines
......@@ -574,6 +637,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if pipeline.F_pagedkv == 't':
# we only use batch mode kernels to handle (paged-) kvcache problems
continue
k = Kernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment