Commit 0fe4fb38 authored by fsx950223

init implementation

parents 9eb23b04 cc974f0f
......@@ -12,6 +12,7 @@ add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_pe
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_example_executable(example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
......@@ -44,6 +44,7 @@ Kernel outputs:
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
#define FLASH_ATTN_IMPLENTATION 0
......@@ -52,6 +53,7 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
......@@ -63,6 +65,7 @@ using DataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
using ZDataType = U16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
......@@ -160,6 +163,7 @@ using DeviceGemmInstance =
NumDimK,
NumDimO,
DataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
......@@ -179,9 +183,9 @@ using DeviceGemmInstance =
256,
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
......@@ -189,7 +193,7 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -212,7 +216,7 @@ using DeviceGemmInstance =
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
......@@ -252,11 +256,15 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
PassThrough,
Scale>;
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorS,
typename TensorP,
typename TensorZ,
typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
......@@ -266,7 +274,11 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
TensorY& y_g_m_o,
TensorLSE& lse_g_m)
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ushort p_dropout_in_16bits,
float rp_dropout)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
......@@ -294,6 +306,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_softmax_invoker.Run(ref_softmax_argument);
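// P_dropped = dropout(P), driven by the precomputed 16-bit random tensor Z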
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argument =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argument);
// Y = P * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
......@@ -315,6 +333,12 @@ int run(int argc, char* argv[])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
float K = 128;
float alpha = 1.f / std::sqrt(K);
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
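The three dropout constants above fit together as follows; a quick standalone sanity check (illustrative, not part of this commit):

#include <cassert>
#include <cmath>
#include <cstdint>

int main()
{
    float p_drop    = 0.2f;
    float p_dropout = 1.f - p_drop; // probability of keeping an element
    // A uniform 16-bit random value r is kept iff r < threshold, so
    // threshold = floor(p_dropout * 65535) keeps ~80% of all elements.
    uint16_t threshold = uint16_t(std::floor(p_dropout * 65535.f));
    // Survivors are rescaled by 1/p_dropout so the expected sum is unchanged.
    float rp_dropout = 1.f / p_dropout;
    assert(threshold == 52428);
    assert(std::fabs(rp_dropout - 1.25f) < 1e-6f);
    return 0;
}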
bool input_permute = false;
bool output_permute = false;
......@@ -358,6 +382,7 @@ int run(int argc, char* argv[])
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<const void*> p_q;
std::vector<const void*> p_k;
std::vector<void*> p_z;
std::vector<const void*> p_v;
std::vector<const void*> p_y;
std::vector<const void*> p_lse;
......@@ -368,6 +393,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<DataType>> q_g_m_ks;
std::vector<Tensor<DataType>> k_g_n_ks;
std::vector<Tensor<ZDataType>> z_g_m_ns;
std::vector<Tensor<DataType>> v_g_n_os;
std::vector<Tensor<AccDataType>> s_g_m_ns;
std::vector<Tensor<DataType>> p_g_m_ns;
......@@ -376,6 +402,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<DataType>> k_tensors;
std::vector<Tensor<DataType>> v_tensors;
std::vector<Tensor<DataType>> y_tensors;
std::vector<Tensor<ZDataType>> z_tensors;
std::vector<Tensor<LSEDataType>> lse_tensors;
std::vector<Tensor<DataType>> qgrad_tensors;
std::vector<Tensor<DataType>> kgrad_tensors;
......@@ -384,6 +411,7 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> q_tensors_device;
std::vector<DeviceMemPtr> k_tensors_device;
std::vector<DeviceMemPtr> z_tensors_device;
std::vector<DeviceMemPtr> v_tensors_device;
std::vector<DeviceMemPtr> y_tensors_device;
std::vector<DeviceMemPtr> lse_tensors_device;
......@@ -425,6 +453,11 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
// The softmax statistic log-sum-exp (LSE) is saved to speed up the softmax calculation in the
// backward pass:
//   Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//      = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//      = exp(Si - LSE), where LSE = log(exp(S0) + exp(S1) + ...)
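The identity above means the backward pass can recompute each softmax row from the raw scores and the stored LSE alone; a minimal host sketch (illustrative, not part of this commit):

#include <cmath>
#include <vector>

// Recompute one softmax row from raw scores s and the stored LSE:
// Pi = exp(Si - LSE), with LSE = log(exp(S0) + exp(S1) + ...).
std::vector<float> softmax_from_lse(const std::vector<float>& s, float lse)
{
    std::vector<float> p(s.size());
    for(std::size_t i = 0; i < s.size(); ++i)
        p[i] = std::exp(s[i] - lse);
    return p;
}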
......@@ -438,6 +471,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
......@@ -460,6 +495,7 @@ int run(int argc, char* argv[])
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
......@@ -472,6 +508,7 @@ int run(int argc, char* argv[])
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
}
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<ZDataType>{0});
switch(init_method)
{
case 0: break;
......@@ -535,11 +572,13 @@ int run(int argc, char* argv[])
}
Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<DataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
......@@ -547,14 +586,27 @@ int run(int argc, char* argv[])
k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
v_gs_os_ns.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(
q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
......@@ -564,6 +616,7 @@ int run(int argc, char* argv[])
q_g_m_ks.push_back(q_g_m_k);
k_g_n_ks.push_back(k_g_n_k);
z_g_m_ns.push_back(z_g_m_n);
v_g_n_os.push_back(v_g_n_o);
s_g_m_ns.push_back(s_g_m_n);
p_g_m_ns.push_back(p_g_m_n);
......@@ -572,12 +625,15 @@ int run(int argc, char* argv[])
k_tensors.push_back(k_gs_ns_ks);
v_tensors.push_back(v_gs_os_ns);
y_tensors.push_back(y_gs_ms_os);
z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms);
ygrad_tensors.push_back(ygrad_gs_ms_os);
q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
z_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ZDataType) * z_gs_ms_ns.GetElementSpaceSize()));
v_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
y_tensors_device.emplace_back(
......@@ -594,6 +650,7 @@ int run(int argc, char* argv[])
std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
z_tensors_device.back()->ToDevice(z_gs_ms_ns.data());
v_tensors_device.back()->ToDevice(v_gs_os_ns.data());
y_tensors_device.back()->ToDevice(y_gs_ms_os.data());
lse_tensors_device.back()->ToDevice(lse_gs_ms.data());
......@@ -603,6 +660,7 @@ int run(int argc, char* argv[])
ygrad_tensors_device.back()->ToDevice(ygrad_gs_ms_os.data());
p_q.push_back(q_tensors_device.back()->GetDeviceBuffer());
p_k.push_back(k_tensors_device.back()->GetDeviceBuffer());
p_z.push_back(z_tensors_device.back()->GetDeviceBuffer());
p_v.push_back(v_tensors_device.back()->GetDeviceBuffer());
p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
......@@ -611,8 +669,10 @@ int run(int argc, char* argv[])
p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
}
auto argument = gemm.MakeArgument(p_q,
auto argument =
gemm.MakeArgument(p_q,
p_k,
p_z,
p_v,
p_y,
p_lse,
......@@ -627,7 +687,9 @@ int run(int argc, char* argv[])
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{});
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
......@@ -640,13 +702,6 @@ int run(int argc, char* argv[])
return 0;
}
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
......@@ -16,11 +16,14 @@ struct BlockwiseDropout
static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
template <typename CThreadBuffer>
template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
......@@ -47,6 +50,42 @@ struct BlockwiseDropout
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph, ZThreadBuffer& z_thread_buf)
{
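// This overload additionally records the raw 16-bit philox values into
// z_thread_buf so the same mask can be reproduced elsewhere (e.g. by the
// host reference or the backward pass).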
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
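// get_random_8x16 fills 8 16-bit values per call; MRepeat * KRepeat is
// assumed to be a multiple of 8 so tmp is fully populated.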
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8));
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
ushort p_dropout_16bits;
DataType p_dropout_rescale;
};
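For reference, the per-element decision implemented by execute_dropout can be written out on the host; a minimal sketch (names are illustrative, not from this commit):

#include <cstdint>

// One element of blockwise dropout: rnd is a philox-generated 16-bit value.
// In sign-bit mode the keep/drop mask is encoded by negating dropped values
// (so a later pass can recover the mask from the sign); otherwise dropped
// values are zeroed and survivors are rescaled by 1 / p_keep.
float host_dropout(uint16_t rnd, uint16_t threshold, float val,
                   float rescale, bool using_sign_bit)
{
    const bool keep = rnd < threshold;
    if(using_sign_bit)
        return keep ? val : -val;
    return keep ? val * rescale : 0.f;
}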
......
......@@ -7,6 +7,7 @@
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
// #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
......
......@@ -95,6 +95,8 @@ struct Scale
y = scale_ * x;
};
__host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
float scale_;
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename RefDataType, typename InDataType, typename OutDataType>
struct ReferenceDropout : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<RefDataType>& ref,
const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
RefDataType p_dropout_in_16bits,
float rp_dropout)
: ref_(ref),
in_(in),
out_(out),
p_dropout_in_16bits_(p_dropout_in_16bits),
rp_dropout_(ck::type_convert<OutDataType>(rp_dropout))
{
}
const Tensor<RefDataType>& ref_;
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
RefDataType p_dropout_in_16bits_;
OutDataType rp_dropout_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) =
arg.ref_(idx) < arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
});
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<RefDataType>& ref,
const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
RefDataType p_dropout_in_16bits,
float rp_dropout)
{
return Argument{ref, in, out, p_dropout_in_16bits, rp_dropout};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceDropout"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
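A minimal host-side usage sketch of ReferenceDropout as defined above (shapes, threshold, and rescale constants are illustrative, not from this commit):

using RefDropout = ck::tensor_operation::host::ReferenceDropout<unsigned short, float, float>;

Tensor<unsigned short> z({2, 4, 4}); // 16-bit random values (e.g. philox output)
Tensor<float> p({2, 4, 4});          // softmax probabilities
Tensor<float> p_drop({2, 4, 4});     // dropout result

// keep iff z < 52428 (p_keep = 0.8); rescale survivors by 1/0.8
auto arg     = RefDropout::MakeArgument(z, p, p_drop, 52428, 1.25f);
auto invoker = RefDropout::MakeInvoker();
invoker.Run(arg);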
......@@ -149,13 +149,6 @@ struct ReferenceSoftmax : public device::BaseOperator
ck::type_convert<AccDataType>(
arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) +
arg.beta_ * self(idx);
// printf(
// "exponent %f, exp() = %f\n",
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx))),
// std::exp(
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))));
});
return 0;
......