Commit bc9e2f25 authored by ltqin

host dV correct

parent 4e79cc4b
@@ -44,6 +44,7 @@ Kernel outputs:
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -187,12 +188,16 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
                                                  PassThrough,
                                                  PassThrough,
                                                  Scale>;

+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
+
 template <typename TensorQ,
           typename TensorK,
           typename TensorV,
           typename TensorS,
           typename TensorP,
+          typename TensorZ,
           typename TensorY,
           typename TensorLSE = TensorP>
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
@@ -202,7 +207,10 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorS& s_g_m_n,
                             TensorP& p_g_m_n,
                             TensorY& y_g_m_o,
-                            TensorLSE& lse_g_m)
+                            TensorLSE& lse_g_m,
+                            TensorP& p_drop_g_m_n,
+                            TensorZ& z_g_m_n,
+                            ushort p_dropout)
 {
     // S = alpha * Q * K^T
     auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
@@ -230,11 +238,18 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_softmax_invoker.Run(ref_softmax_argument);

-    // Y = P * V
+    // P_dropout
+    auto ref_dropout = ReferenceDropoutInstance{};
+    auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+    auto ref_dropout_argument = ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout);
+    ref_dropout_invoker.Run(ref_dropout_argument);
+    // std::cout << "p_drop_g_m_n ref:\n" << p_drop_g_m_n;
+
+    // Y = P_dropout * V
     auto ref_gemm1 = ReferenceGemm1Instance{};
     auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
     auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-        p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
+        p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
     ref_gemm1_invoker.Run(ref_gemm1_argument);
 }
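Note that the host forward path applies a pure keep/drop mask here; the rp_dropout rescaling only appears later, in the gradient GEMMs. A minimal standalone sketch of the same keep/drop rule (a hypothetical helper, using plain std::vector in place of the CK Tensor class):

#include <cstdint>
#include <vector>

// Sketch only: keep an element when its 16-bit random value is below the
// keep-probability threshold, otherwise zero it (same predicate as ReferenceDropout).
std::vector<float> apply_dropout_mask(const std::vector<uint16_t>& z,
                                      const std::vector<float>& p,
                                      uint16_t p_dropout_in_16bits)
{
    std::vector<float> p_drop(p.size());
    for(std::size_t i = 0; i < p.size(); ++i)
        p_drop[i] = (z[i] < p_dropout_in_16bits) ? p[i] : 0.0f;
    return p_drop;
}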
@@ -263,6 +278,7 @@ int run(int argc, char* argv[])
     float p_drop = 0.2;
     float p_dropout = 1 - p_drop;
+    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
     const unsigned long long seed = 1;
     const unsigned long long offset = 0;
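The keep probability is mapped to a 16-bit threshold so it can be compared directly against the ushort random values in z_g_m_n. A quick check of the arithmetic above (a standalone sketch, not part of the commit):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float p_drop    = 0.2f;
    float p_dropout = 1 - p_drop; // keep probability 0.8
    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); // 52428
    // an element is kept when its random ushort is < 52428,
    // i.e. with probability about 52428 / 65536 ~= 0.8
    std::printf("%u\n", p_dropout_in_16bits);
    return 0;
}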
@@ -433,6 +449,7 @@ int run(int argc, char* argv[])
     Tensor<DataType> v_g_n_o({BatchCount, N, O});
     Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
     Tensor<DataType> p_g_m_n({BatchCount, M, N});
+    Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
     Tensor<DataType> y_g_m_o({BatchCount, M, O});
     Tensor<LSEDataType> lse_g_m({BatchCount, M});
@@ -447,7 +464,17 @@ int run(int argc, char* argv[])
     lse_gs_ms.ForEach(
         [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });

-    run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
+    run_attention_fwd_host(q_g_m_k,
+                           k_g_n_k,
+                           v_g_n_o,
+                           alpha,
+                           s_g_m_n,
+                           p_g_m_n,
+                           y_g_m_o,
+                           lse_g_m,
+                           p_drop_g_m_n,
+                           z_g_m_n,
+                           p_dropout_in_16bits);

     y_gs_ms_os.ForEach(
         [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
@@ -551,6 +578,23 @@ int run(int argc, char* argv[])
     bool pass = true;
     if(do_verification)
     {
+        // run the forward host reference again for y, because z_g_m_n has been updated
+        run_attention_fwd_host(q_g_m_k,
+                               k_g_n_k,
+                               v_g_n_o,
+                               alpha,
+                               s_g_m_n,
+                               p_g_m_n,
+                               y_g_m_o,
+                               lse_g_m,
+                               p_drop_g_m_n,
+                               z_g_m_n,
+                               p_dropout_in_16bits);
+        y_gs_ms_os.ForEach([&](auto& self, auto idx) {
+            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
+        });
+        y_device_buf.ToDevice(y_gs_ms_os.mData.data());
+
         kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
         vgrad_device_buf.SetZero();
         invoker.Run(argument, StreamConfig{nullptr, false});
@@ -560,6 +604,7 @@ int run(int argc, char* argv[])
         Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
         Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
         Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
+        Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
         Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
         Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
@@ -581,18 +626,24 @@ int run(int argc, char* argv[])
         auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
         using RefGemmGradArg = ReferenceGemmGradInstance::Argument;

-        // dP = dY * V^T
+        // dP_dropout = dY * V^T
         auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
         ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
+            ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
 #if PRINT_HOST
         {
             std::cout << "===== dP = dY * V^T\n";
-            std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
+            std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
             std::cout << "v_g_o_n ref:\n" << v_g_o_n;
             std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
         }
 #endif

+        // dP = dP_dropout . Z
+        auto ref_dropout = ReferenceDropoutInstance{};
+        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+        auto ref_dropout_argument =
+            ref_dropout.MakeArgument(z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits);
+        ref_dropout_invoker.Run(ref_dropout_argument);

         // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
         sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
@@ -616,22 +667,30 @@ int run(int argc, char* argv[])
         }
 #endif

-        // dV = P^T * dY
-        auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
-        ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
+        // dV = rp_dropout * P_drop^T * dY
+        auto pdrop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
+        ref_gemm_grad_invoker.Run(RefGemmGradArg{pdrop_g_n_m,
+                                                 ygrad_g_m_o,
+                                                 vgrad_g_n_o,
+                                                 PassThrough{},
+                                                 PassThrough{},
+                                                 Scale{rp_dropout}});
 #if PRINT_HOST
         {
             std::cout << "===== dV = P^T * dY\n";
-            std::cout << "p_g_n_m ref:\n" << p_g_n_m;
+            std::cout << "pdrop_g_n_m ref:\n" << pdrop_g_n_m;
             std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
             std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
         }
 #endif

         // dQ = alpha * dS * K
-        ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
+        ref_gemm_grad_invoker.Run(RefGemmGradArg{sgrad_g_m_n,
+                                                 k_g_n_k,
+                                                 qgrad_g_m_k,
+                                                 PassThrough{},
+                                                 PassThrough{},
+                                                 Scale{alpha * rp_dropout}});
 #if PRINT_HOST
         {
             std::cout << "===== dQ = alpha * dS * K\n";
@@ -643,8 +702,12 @@ int run(int argc, char* argv[])
         // dK = alpha * dS^T * Q
         auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
-        ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
+        ref_gemm_grad_invoker.Run(RefGemmGradArg{sgrad_g_n_m,
+                                                 q_g_m_k,
+                                                 kgrad_g_n_k,
+                                                 PassThrough{},
+                                                 PassThrough{},
+                                                 Scale{alpha * rp_dropout}});
 #if PRINT_HOST
         {
             std::cout << "===== dK = alpha * dS^T * Q\n";
......
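For orientation, after this change the host reference computes the following chain (notation follows the comments above, with Z the per-element random tensor, t = p_dropout_in_16bits, and rp_dropout = 1 / (1 - p_drop)):

S       = alpha * Q * K^T
P       = softmax(S)
P_drop  = (Z < t) ? P : 0              // ReferenceDropout
Y       = P_drop * V
dP_drop = dY * V^T
dP      = (Z < t) ? dP_drop : 0        // ReferenceDropout reused on the gradient
dS      = P .* (dP - rowsum(dY .* Y))
dV      = rp_dropout * P_drop^T * dY
dQ      = alpha * rp_dropout * dS * K
dK      = alpha * rp_dropout * dS^T * Q

The reference_dropout.hpp operator used for both masking steps is the new file added below.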
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include <memory>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename RefDataType, typename InDataType, typename OutDataType>
struct ReferenceDropout : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<RefDataType>& ref,
                 const Tensor<InDataType>& in,
                 Tensor<OutDataType>& out,
                 RefDataType p_dropout)
            : ref_(ref), in_(in), out_(out), p_dropout_(p_dropout)
        {
        }

        const Tensor<RefDataType>& ref_;
        const Tensor<InDataType>& in_;
        Tensor<OutDataType>& out_;
        RefDataType p_dropout_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
            arg.out_.ForEach([&](auto& self, auto idx) {
                // keep the input element when the reference random value is below
                // the keep-probability threshold; otherwise drop it to zero
                self(idx) = arg.ref_(idx) < arg.p_dropout_ ? arg.in_(idx) : 0;
            });
            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<RefDataType>& ref,
                             const Tensor<InDataType>& in,
                             Tensor<OutDataType>& out,
                             RefDataType p_dropout)
    {
        return Argument{ref, in, out, p_dropout};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceDropout"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
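A minimal usage sketch of the new reference op (hypothetical function, assuming the host Tensor utility included above; uint16_t stands in for the ushort alias used in the example, and 52428 is the threshold corresponding to p_drop = 0.2):

#include <cstdint>

#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"

void reference_dropout_example()
{
    using RefDropout = ck::tensor_operation::host::ReferenceDropout<uint16_t, float, float>;

    Tensor<uint16_t> z({2, 4, 8});   // random values, e.g. copied back from the device kernel
    Tensor<float> p({2, 4, 8});      // softmax probabilities, filled by the caller
    Tensor<float> p_drop({2, 4, 8}); // masked output

    auto invoker  = RefDropout::MakeInvoker();
    auto argument = RefDropout::MakeArgument(z, p, p_drop, uint16_t(52428)); // keep when z < 52428
    invoker.Run(argument);
}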