Commit 0fe4fb38 authored by fsx950223

init implementation

parents 9eb23b04 cc974f0f
@@ -12,6 +12,7 @@ add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_pe
 add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
 add_example_executable(example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
@@ -44,6 +44,7 @@ Kernel outputs:
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
 #define FLASH_ATTN_IMPLENTATION 0
@@ -52,6 +53,7 @@ using S = ck::Sequence<Is...>;
 using F16 = ck::half_t;
 using F32 = float;
+using U16 = unsigned short;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale = ck::tensor_operation::element_wise::Scale;
@@ -63,6 +65,7 @@ using DataType = F16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
+using ZDataType = U16;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -160,6 +163,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
@@ -179,9 +183,9 @@ using DeviceGemmInstance =
         256,
         128, // MPerBlock
         128, // NPerBlock
-        64,  // KPerBlock
-        64,  // Gemm1NPerBlock
-        32,  // Gemm1KPerBlock
+        32,  // KPerBlock
+        128, // Gemm1NPerBlock
+        64,  // Gemm1KPerBlock
         8,   // AK1
         8,   // BK1
         2,   // B1K1
@@ -189,7 +193,7 @@ using DeviceGemmInstance =
         32, // NPerXDL
         1,  // MXdlPerWave
         4,  // NXdlPerWave
-        2,  // Gemm1NXdlPerWave
+        4,  // Gemm1NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -212,7 +216,7 @@ using DeviceGemmInstance =
         2,
         false,
         1, // CShuffleMXdlPerWavePerShuffle
-        2, // CShuffleNXdlPerWavePerShuffle
+        4, // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization
@@ -252,11 +256,15 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
                                                                        PassThrough,
                                                                        Scale>;
 
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
+
 template <typename TensorQ,
           typename TensorK,
           typename TensorV,
           typename TensorS,
           typename TensorP,
+          typename TensorZ,
           typename TensorY,
           typename TensorLSE = TensorP>
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
@@ -266,7 +274,11 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorS& s_g_m_n,
                             TensorP& p_g_m_n,
                             TensorY& y_g_m_o,
-                            TensorLSE& lse_g_m)
+                            TensorLSE& lse_g_m,
+                            TensorP& p_drop_g_m_n,
+                            TensorZ& z_g_m_n,
+                            ushort p_dropout_in_16bits,
+                            float rp_dropout)
 {
     // S = alpha * Q * K^T
     auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
@@ -294,6 +306,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_softmax_invoker.Run(ref_softmax_argument);
 
+    auto ref_dropout = ReferenceDropoutInstance{};
+    auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+    auto ref_dropout_argment =
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+    ref_dropout_invoker.Run(ref_dropout_argment);
+
     // Y = P * V
     auto ref_gemm1 = ReferenceGemm1Instance{};
     auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
@@ -315,6 +333,12 @@ int run(int argc, char* argv[])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
     float K = 128;
     float alpha = 1.f / std::sqrt(K);
+    float p_drop = 0.2;
+    float p_dropout = 1 - p_drop;
+    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    float rp_dropout = 1.0 / p_dropout;
+    const unsigned long long seed = 1;
+    const unsigned long long offset = 0;
 
     bool input_permute = false;
     bool output_permute = false;
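The 16-bit threshold above encodes the keep-probability as floor((1 - p_drop) * 65535): a uniform 16-bit draw r keeps an element iff r < p_dropout_in_16bits, and survivors are rescaled by rp_dropout = 1 / (1 - p_drop) so the expectation is preserved. A minimal host-side sanity check of that encoding (a hypothetical standalone sketch, not part of this commit; std::mt19937 stands in for the device-side Philox stream):

```cpp
#include <cmath>
#include <cstdint>
#include <random>

int main()
{
    const float p_drop    = 0.2f;
    const float p_dropout = 1.f - p_drop;
    const uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
    const float rp_dropout = 1.f / p_dropout;

    std::mt19937 gen(1); // stand-in for the device-side Philox stream
    std::uniform_int_distribution<uint32_t> dist(0, 65535);

    double sum  = 0.0;
    const int n = 1 << 20;
    for(int i = 0; i < n; i++)
    {
        const uint16_t r = static_cast<uint16_t>(dist(gen));
        const float x    = 1.f; // dropout applied to a constant input
        sum += (r < p_dropout_in_16bits) ? x * rp_dropout : 0.f;
    }
    // sum / n lands close to 1: the threshold + rescale preserves the mean.
}
```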
@@ -358,6 +382,7 @@ int run(int argc, char* argv[])
     using DeviceMemPtr = std::unique_ptr<DeviceMem>;
     std::vector<const void*> p_q;
     std::vector<const void*> p_k;
+    std::vector<void*> p_z;
     std::vector<const void*> p_v;
     std::vector<const void*> p_y;
     std::vector<const void*> p_lse;
@@ -368,6 +393,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<DataType>> q_g_m_ks;
     std::vector<Tensor<DataType>> k_g_n_ks;
+    std::vector<Tensor<ZDataType>> z_g_m_ns;
     std::vector<Tensor<DataType>> v_g_n_os;
     std::vector<Tensor<AccDataType>> s_g_m_ns;
     std::vector<Tensor<DataType>> p_g_m_ns;
@@ -376,6 +402,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<DataType>> k_tensors;
     std::vector<Tensor<DataType>> v_tensors;
     std::vector<Tensor<DataType>> y_tensors;
+    std::vector<Tensor<ZDataType>> z_tensors;
     std::vector<Tensor<LSEDataType>> lse_tensors;
     std::vector<Tensor<DataType>> qgrad_tensors;
     std::vector<Tensor<DataType>> kgrad_tensors;
@@ -384,6 +411,7 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> q_tensors_device;
     std::vector<DeviceMemPtr> k_tensors_device;
+    std::vector<DeviceMemPtr> z_tensors_device;
     std::vector<DeviceMemPtr> v_tensors_device;
     std::vector<DeviceMemPtr> y_tensors_device;
     std::vector<DeviceMemPtr> lse_tensors_device;
@@ -425,6 +453,11 @@ int run(int argc, char* argv[])
                 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
                 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
+        std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+        std::vector<ck::index_t> z_gs_ms_ns_strides =
+            input_permute
+                ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+                : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
 
         // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
         // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
         //         = exp(Si) / exp(log(sum(exp() + ...)))
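The identity the comment describes can be written out as a hypothetical host-side helper (not part of this commit): given a row of logits S and its stored LSE = log(sum_j exp(S_j)), each softmax value is recovered with a single exp and no renormalization pass over the row.

```cpp
#include <cmath>
#include <vector>

// Recompute a softmax row from logits s and its stored log-sum-exp:
// Pi = exp(Si) / sum_j exp(Sj) = exp(Si - lse), with lse = log(sum_j exp(Sj)).
std::vector<float> softmax_from_lse(const std::vector<float>& s, float lse)
{
    std::vector<float> p(s.size());
    for(std::size_t i = 0; i < s.size(); ++i)
        p[i] = std::exp(s[i] - lse);
    return p;
}
```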
@@ -438,6 +471,8 @@ int run(int argc, char* argv[])
                                       q_gs_ms_ks_strides,
                                       k_gs_ns_ks_lengths,
                                       k_gs_ns_ks_strides,
+                                      z_gs_ms_ns_lengths,
+                                      z_gs_ms_ns_strides,
                                       v_gs_os_ns_lengths,
                                       v_gs_os_ns_strides,
                                       y_gs_ms_os_lengths,
@@ -460,6 +495,7 @@ int run(int argc, char* argv[])
         Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
         Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+        Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
         Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
         Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
         Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
@@ -472,6 +508,7 @@ int run(int argc, char* argv[])
             std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
             std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
         }
+        z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
         switch(init_method)
         {
         case 0: break;
@@ -535,11 +572,13 @@ int run(int argc, char* argv[])
         }
 
         Tensor<DataType> q_g_m_k({BatchCount, M, K});
         Tensor<DataType> k_g_n_k({BatchCount, N, K});
+        Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
         Tensor<DataType> v_g_n_o({BatchCount, N, O});
         Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
         Tensor<DataType> p_g_m_n({BatchCount, M, N});
         Tensor<DataType> y_g_m_o({BatchCount, M, O});
         Tensor<LSEDataType> lse_g_m({BatchCount, M});
+        Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
 
         q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
             q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
@@ -547,14 +586,27 @@ int run(int argc, char* argv[])
         k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
             k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
+        z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+            z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
         v_gs_os_ns.ForEach([&](auto& self, auto idx) {
             v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
         });
         lse_gs_ms.ForEach(
             [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
 
-        run_attention_fwd_host(
-            q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
+        run_attention_fwd_host(q_g_m_k,
+                               k_g_n_k,
+                               v_g_n_o,
+                               alpha,
+                               s_g_m_n,
+                               p_g_m_n,
+                               y_g_m_o,
+                               lse_g_m,
+                               p_drop_g_m_n,
+                               z_g_m_n,
+                               p_dropout_in_16bits,
+                               rp_dropout);
 
         y_gs_ms_os.ForEach([&](auto& self, auto idx) {
             self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
@@ -564,6 +616,7 @@ int run(int argc, char* argv[])
         q_g_m_ks.push_back(q_g_m_k);
         k_g_n_ks.push_back(k_g_n_k);
+        z_g_m_ns.push_back(z_g_m_n);
         v_g_n_os.push_back(v_g_n_o);
         s_g_m_ns.push_back(s_g_m_n);
         p_g_m_ns.push_back(p_g_m_n);
@@ -572,12 +625,15 @@ int run(int argc, char* argv[])
         k_tensors.push_back(k_gs_ns_ks);
         v_tensors.push_back(v_gs_os_ns);
         y_tensors.push_back(y_gs_ms_os);
+        z_tensors.push_back(z_gs_ms_ns);
         lse_tensors.push_back(lse_gs_ms);
         ygrad_tensors.push_back(ygrad_gs_ms_os);
         q_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
         k_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
+        z_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(ZDataType) * z_gs_ms_ns.GetElementSpaceSize()));
         v_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
         y_tensors_device.emplace_back(
@@ -594,6 +650,7 @@ int run(int argc, char* argv[])
             std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
         q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
         k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
+        z_tensors_device.back()->ToDevice(z_gs_ms_ns.data());
         v_tensors_device.back()->ToDevice(v_gs_os_ns.data());
         y_tensors_device.back()->ToDevice(y_gs_ms_os.data());
         lse_tensors_device.back()->ToDevice(lse_gs_ms.data());
@@ -603,6 +660,7 @@ int run(int argc, char* argv[])
         ygrad_tensors_device.back()->ToDevice(ygrad_gs_ms_os.data());
         p_q.push_back(q_tensors_device.back()->GetDeviceBuffer());
         p_k.push_back(k_tensors_device.back()->GetDeviceBuffer());
+        p_z.push_back(z_tensors_device.back()->GetDeviceBuffer());
         p_v.push_back(v_tensors_device.back()->GetDeviceBuffer());
         p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
         p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
@@ -611,8 +669,10 @@ int run(int argc, char* argv[])
         p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
         p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
     }
-    auto argument = gemm.MakeArgument(p_q,
+    auto argument =
+        gemm.MakeArgument(p_q,
                           p_k,
+                          p_z,
                           p_v,
                           p_y,
                           p_lse,
@@ -627,7 +687,9 @@ int run(int argc, char* argv[])
                           QKVElementOp{},
                           Scale{alpha},
                           QKVElementOp{},
-                          YElementOp{});
+                          YElementOp{},
+                          p_drop,
+                          std::tuple<unsigned long long, unsigned long long>(seed, offset));
 
     DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
@@ -640,13 +702,6 @@ int run(int argc, char* argv[])
         return 0;
     }
 
-    if(!gemm.IsSupportedArgument(argument))
-    {
-        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
-        return 0;
-    }
-
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
@@ -16,11 +16,14 @@ struct BlockwiseDropout
     static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
    static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
 
-    template <typename CThreadBuffer>
+    template <typename CThreadBuffer, bool using_sign_bit = false>
     __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
     {
         auto execute_dropout = [&](bool keep, DataType val) {
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
                 return keep ? val * p_dropout_rescale : float(0);
         };
@@ -47,6 +50,42 @@ struct BlockwiseDropout
         });
     }
 
+    template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
+    __host__ __device__ void
+    ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph, ZThreadBuffer& z_thread_buf)
+    {
+        auto execute_dropout = [&](bool keep, DataType val) {
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
+                return keep ? val * p_dropout_rescale : float(0);
+        };
+
+        constexpr int tmp_size = MRepeat * KRepeat;
+
+        int philox_calls = tmp_size / 8;
+
+        ushort tmp[tmp_size];
+        for(int i = 0; i < philox_calls; i++)
+        {
+            ph.get_random_8x16((tmp + i * 8));
+        }
+
+        block_sync_lds();
+
+        int tmp_index = 0;
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) =
+                    execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
+                z_thread_buf(offset) = tmp[tmp_index];
+                tmp_index = tmp_index + 1;
+            });
+        });
+    }
+
     ushort p_dropout_16bits;
     DataType p_dropout_rescale;
 };
......
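The new ApplyDropout overload draws randomness in fixed chunks of eight 16-bit values per get_random_8x16 call, so it implicitly assumes MRepeat * KRepeat is a multiple of 8. A host-side sketch of that indexing (hypothetical; philox_stub is a stand-in for ck::philox, not the real counter-based generator):

```cpp
#include <cstdint>

// Hypothetical stand-in for ck::philox: yields 8 16-bit values per call.
struct philox_stub
{
    uint64_t state = 1;
    void get_random_8x16(uint16_t* out)
    {
        for(int i = 0; i < 8; i++)
        {
            state  = state * 6364136223846793005ull + 1442695040888963407ull; // LCG, not Philox
            out[i] = static_cast<uint16_t>(state >> 48);
        }
    }
};

// Mirrors the loop structure in ApplyDropout: tmp_size / 8 generator calls
// fill the whole M x K thread slice, one 16-bit draw per element.
template <int MRepeat, int KRepeat>
void fill_dropout_mask(uint16_t (&tmp)[MRepeat * KRepeat], philox_stub& ph)
{
    static_assert((MRepeat * KRepeat) % 8 == 0, "thread slice must be a multiple of 8");
    constexpr int philox_calls = (MRepeat * KRepeat) / 8;
    for(int i = 0; i < philox_calls; i++)
        ph.get_random_8x16(tmp + i * 8);
}

int main()
{
    philox_stub ph;
    uint16_t tmp[4 * 8];
    fill_dropout_mask<4, 8>(tmp, ph);
}
```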
@@ -7,6 +7,7 @@
 #include <sstream>
 
 #include "ck/utility/common_header.hpp"
+#include "ck/utility/philox_rand.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 // #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
......
@@ -95,6 +95,8 @@ struct Scale
         y = scale_ * x;
     };
 
+    __host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
+
     float scale_;
 };
......
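Append lets a caller fold an extra factor into an already-constructed Scale functor, so for example a dropout rescale can reuse the existing alpha multiply instead of adding a second element-wise pass. A minimal usage sketch (assuming this is the Scale from CK's element-wise operation header; the numeric values are illustrative only):

```cpp
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

int main()
{
    // scale_ starts at alpha and becomes alpha * rp_dropout after Append.
    ck::tensor_operation::element_wise::Scale scale_op{0.0884f /* alpha = 1/sqrt(128) */};
    scale_op.Append(1.25f /* rp_dropout for p_drop = 0.2 */);
}
```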
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <typename RefDataType, typename InDataType, typename OutDataType>
struct ReferenceDropout : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<RefDataType>& ref,
                 const Tensor<InDataType>& in,
                 Tensor<OutDataType>& out,
                 RefDataType p_dropout_in_16bits,
                 float rp_dropout)
            : ref_(ref),
              in_(in),
              out_(out),
              p_dropout_in_16bits_(p_dropout_in_16bits),
              rp_dropout_(ck::type_convert<OutDataType>(rp_dropout))
        {
        }

        const Tensor<RefDataType>& ref_;
        const Tensor<InDataType>& in_;
        Tensor<OutDataType>& out_;
        RefDataType p_dropout_in_16bits_;
        OutDataType rp_dropout_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
            arg.out_.ForEach([&](auto& self, auto idx) {
                self(idx) =
                    arg.ref_(idx) < arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
            });

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<RefDataType>& ref,
                             const Tensor<InDataType>& in,
                             Tensor<OutDataType>& out,
                             RefDataType p_dropout_in_16bits,
                             float rp_dropout)
    {
        return Argument{ref, in, out, p_dropout_in_16bits, rp_dropout};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceDropout"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
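Usage mirrors the call site in the example file above: the ref tensor holds the 16-bit random draws, in is the softmax output, and out receives the dropped, rescaled values. A host-side sketch (hypothetical shapes and threshold values):

```cpp
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
#include "ck/library/utility/host_tensor.hpp"

int main()
{
    using F16 = ck::half_t;
    Tensor<ushort> z({2, 64, 64});    // 16-bit random draws
    Tensor<F16> p({2, 64, 64});       // softmax output
    Tensor<F16> p_drop({2, 64, 64});  // dropped, rescaled result

    using RefDropout = ck::tensor_operation::host::ReferenceDropout<ushort, F16, F16>;
    auto op       = RefDropout{};
    auto invoker  = op.MakeInvoker();
    // keep iff z < floor(0.8 * 65535) = 52428; rescale survivors by 1 / 0.8.
    auto argument = op.MakeArgument(z, p, p_drop, ushort(52428), 1.25f);
    invoker.Run(argument);
}
```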
@@ -149,13 +149,6 @@ struct ReferenceSoftmax : public device::BaseOperator
                     ck::type_convert<AccDataType>(
                         arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) +
                     arg.beta_ * self(idx);
-                // printf(
-                //     "exponent %f, exp() = %f\n",
-                //     ck::type_convert<AccDataType>(arg.in_(idx)) -
-                //     ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx))),
-                //     std::exp(
-                //         ck::type_convert<AccDataType>(arg.in_(idx)) -
-                //         ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))));
             });

             return 0;
......