Commit 8e3c6991 authored by fsx950223's avatar fsx950223
Browse files

merge updates

parents 5736b460 6fd1490b
......@@ -3,16 +3,14 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16 grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_multihead_attention_forward_fp16 grouped_multihead_attention_forward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_forward_fp16 batched_multihead_attention_forward_fp16.cpp)
add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_example_executable(example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
......@@ -43,23 +43,27 @@ Kernel outputs:
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
// Shorthand for a compile-time ck::Sequence built from a pack of ck::index_t values.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// Scalar type shorthands.
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
// Element-wise functors from ck::tensor_operation::element_wise, given short local names.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
// Element-wise operations; QKVElementOp and YElementOp instances are passed to
// gemm.MakeArgument in run().
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using VElementOp = Scale;
// Element type of the q/k/v/y host tensors and of their device buffers.
using DataType = F16;
// Element type of the host reference score tensor s_g_m_n.
using AccDataType = F32;
using ShuffleDataType = F32;
// Element type of the log-sum-exp (LSE) tensors.
using LSEDataType = F32;
// Element type of the z tensors (z_gs_ms_ns / z_g_m_n) that feed the reference dropout.
using ZDataType = U16;
// Empty tuples. NOTE(review): presumably means "no bias tensors" -- the bias arguments are
// passed as {} in run(); confirm against the device op.
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
......@@ -91,6 +95,7 @@ using DeviceGemmInstance =
NumDimK,
NumDimO,
DataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
......@@ -182,12 +187,16 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
PassThrough,
PassThrough,
Scale>;
// Host reference dropout used by the CPU verification path.
// The first template argument is spelled with the file's ZDataType alias because it has to
// match the element type of the z tensor handed to MakeArgument (z_g_m_n is a
// Tensor<ZDataType>). ZDataType is `unsigned short`, i.e. the same type the non-standard
// `ushort` typedef named, so the instantiation is unchanged while staying portable and in
// sync with ZDataType.
using ReferenceDropoutInstance =
    ck::tensor_operation::host::ReferenceDropout<ZDataType, DataType, DataType>;
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorS,
typename TensorP,
typename TensorZ,
typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
......@@ -197,7 +206,11 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
TensorY& y_g_m_o,
TensorLSE& lse_g_m)
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ushort p_dropout_in_16bits,
float rp_dropout)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
......@@ -225,11 +238,18 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_softmax_invoker.Run(ref_softmax_argument);
// Y = P * V
// P_dropped
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
}
......@@ -256,6 +276,13 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
......@@ -321,6 +348,11 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
......@@ -332,6 +364,7 @@ int run(int argc, char* argv[])
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
......@@ -339,10 +372,12 @@ int run(int argc, char* argv[])
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "z_gs_ms_ks: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
switch(init_method)
{
case 0: break;
......@@ -408,9 +443,11 @@ int run(int argc, char* argv[])
// calculate y & log-sum-exp beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<DataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
......@@ -418,12 +455,25 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
......@@ -433,6 +483,7 @@ int run(int argc, char* argv[])
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
......@@ -443,6 +494,7 @@ int run(int argc, char* argv[])
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
......@@ -450,11 +502,59 @@ int run(int argc, char* argv[])
kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero();
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// get z matrix
{
auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
invoker.Run(argument, StreamConfig{nullptr, false});
}
// no need to output the z matrix this time
auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
......@@ -468,6 +568,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
......@@ -481,15 +583,11 @@ int run(int argc, char* argv[])
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{});
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
vgrad_device_buf.SetZero();
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// 5 GEMM ops in total:
......@@ -511,9 +609,32 @@ int run(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
// copy z matrix data from device
z_device_buf.FromDevice(z_g_m_n.mData.data());
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true;
if(do_verification)
{
// run forward again for y, because z_g_m_n was updated
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
});
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
// call kernel again
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
vgrad_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
......@@ -523,6 +644,7 @@ int run(int argc, char* argv[])
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
......@@ -544,20 +666,26 @@ int run(int argc, char* argv[])
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
using RefGemmGradArg = ReferenceGemmGradInstance::Argument;
// dP = dY * V^T
// dP_dropout = dY * V^T
auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
#if PRINT_HOST
{
std::cout << "===== dP = dY * V^T\n";
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "ygrad_drop_g_m_o ref:\n" << ygrad_drop_g_m_n;
std::cout << "v_g_o_n ref:\n" << v_g_o_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
}
#endif
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// dP = dP_dropout x Z
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0;
for(int o = 0; o < O; o++)
......@@ -578,15 +706,14 @@ int run(int argc, char* argv[])
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
}
#endif
// dV = P^T * dY
auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
// dV = P_drop^T * dY
auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
#if PRINT_HOST
{
std::cout << "===== dV = P^T * dY\n";
std::cout << "p_g_n_m ref:\n" << p_g_n_m;
std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
}
......
......@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
......@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_scale_softmax_gemm_permute_train.inc"
#include "run_batched_multihead_attention_forward.inc"
// Program entry point: hands the command line to run() and uses its result as the exit status.
int main(int argc, char* argv[])
{
    return run(argc, argv);
}
......@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
......@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_scale_softmax_gemm_permute_train.inc"
#include "run_batched_multihead_attention_forward.inc"
// Program entry point: hands the command line to run() and uses its result as the exit status.
int main(int argc, char* argv[])
{
    return run(argc, argv);
}
......@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
......@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp,
CElementOp>;
#include "run_grouped_gemm_scale_softmax_gemm_permute_train.inc"
#include "run_grouped_multihead_attention_forward.inc"
// Program entry point: hands the command line to run() and uses its result as the exit status.
int main(int argc, char* argv[])
{
    return run(argc, argv);
}
......@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
......@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp,
CElementOp>;
#include "run_grouped_gemm_scale_softmax_gemm_permute_train.inc"
#include "run_grouped_multihead_attention_forward.inc"
// Program entry point: hands the command line to run() and uses its result as the exit status.
int main(int argc, char* argv[])
{
    return run(argc, argv);
}
......@@ -84,7 +84,7 @@ template <index_t NumDimG,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
{
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
......
......@@ -88,7 +88,7 @@ template <index_t NumDimG,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
{
struct ProblemDesc
{
......
......@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -47,7 +47,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
kernel_batched_multiheadattention_forward_xdl_cshuffle(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
......@@ -205,25 +205,25 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermuteTrain<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
: public DeviceBatchedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
......@@ -244,7 +244,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -382,7 +382,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
......@@ -648,7 +648,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......@@ -958,7 +958,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -37,7 +37,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
kernel_grouped_multiheadattention_forward_xdl_cshuffle(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -197,25 +197,25 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
: public DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
: public DeviceGroupedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
......@@ -236,25 +236,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>::ProblemDesc;
using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>::ProblemDesc;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -392,7 +392,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
......@@ -705,16 +705,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_>;
kernel_grouped_multiheadattention_forward_xdl_cshuffle<GridwiseGemm,
GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(
stream_config,
......@@ -969,7 +969,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -95,7 +95,7 @@ struct Scale
y = scale_ * x;
};
// Folds an additional factor into the stored scale: scale_ becomes scale_ * scale.
__host__ __device__ void Append(float scale)
{
    scale_ *= scale;
}
// Returns the currently stored scale factor (scale_) by value.
__host__ __device__ auto Value() const { return scale_; }
float scale_;
};
......
......@@ -1169,11 +1169,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
FloatGemmAcc p_dropout,
const float p_drop,
ck::philox& ph)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
const bool is_dropout = p_drop > 0.0f;
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
......@@ -1492,8 +1496,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
n4, // DstScalarPerVector
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......@@ -1603,9 +1607,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
s_element_op);
decltype(scale_rp_dropout)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
scale_rp_dropout);
//
// set up Y dot dY
......@@ -1649,8 +1653,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
......@@ -1748,8 +1752,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
const float scalar = 1.0f / std::sqrt(K);
// Initialize dQ
qgrad_thread_buf.Clear();
......@@ -1830,14 +1832,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
else
{
s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
}
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
[&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......@@ -1847,25 +1850,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
if(is_dropout)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
}
block_sync_lds(); // wait for gemm1 LDS read
......@@ -2225,7 +2232,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
s_element_op};
scale_rp_dropout};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
......
......@@ -83,7 +83,7 @@ template <typename FloatAB,
bool PadN,
bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment