Commit 98ccee74 authored by fsx950223

Merge remote-tracking branch 'origin/attn-bwd-develop' into grouped_api

parents bfa06cf2 32b03f33
...@@ -12,6 +12,7 @@ add_example_executable(example_batched_multihead_attention_forward_bf16 batched_
 add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
 add_example_executable(example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
...@@ -59,7 +59,6 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp = PassThrough;
-using VElementOp = Scale;
 using DataType = F16;
 using GemmDataType = F16;
...@@ -450,7 +449,7 @@ int run(int argc, char* argv[])
 std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
 std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
-std::cout << "z_gs_ms_ks: " << z_gs_ms_ns.mDesc << std::endl;
+std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
 std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
 std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
 std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
...@@ -533,30 +532,8 @@ int run(int argc, char* argv[])
     [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
 k_gs_ns_ks.ForEach(
     [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-z_gs_ms_ns.ForEach(
-    [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
 v_gs_os_ns.ForEach(
     [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
-lse_gs_ms.ForEach(
-    [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
-run_attention_fwd_host(q_g_m_k,
-                       k_g_n_k,
-                       v_g_n_o,
-                       alpha,
-                       s_g_m_n,
-                       p_g_m_n,
-                       y_g_m_o,
-                       lse_g_m,
-                       p_drop_g_m_n,
-                       z_g_m_n,
-                       p_dropout_in_16bits,
-                       rp_dropout);
-y_gs_ms_os.ForEach(
-    [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
-lse_gs_ms.ForEach(
-    [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
 // qkv gradients have the same descriptor as with qkv
 DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
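
For readers following the host-reference setup in the hunk above: the example folds the (G0, G1) batch/head dimensions into a single G = G0 * G1 dimension and, for V, also swaps the last two dimensions so the reference code sees a (G, N, O) layout. Below is a minimal stand-alone sketch of that index mapping; the flat std::vector buffers and the function name are assumptions for illustration, not the example's Tensor/ForEach API.

```cpp
// Minimal host-side sketch (assumed names, packed row-major layout): fold the
// (G0, G1) batch/head dims into one G = G0*G1 dim, and transpose V from
// (G0, G1, O, N) to (G, N, O) as the reference backward expects.
#include <cstddef>
#include <vector>

void fold_and_transpose_v(const std::vector<float>& v_gs_os_ns, // [G0][G1][O][N]
                          std::vector<float>& v_g_n_o,          // [G0*G1][N][O]
                          std::size_t G0, std::size_t G1, std::size_t O, std::size_t N)
{
    for(std::size_t g0 = 0; g0 < G0; ++g0)
        for(std::size_t g1 = 0; g1 < G1; ++g1)
            for(std::size_t o = 0; o < O; ++o)
                for(std::size_t n = 0; n < N; ++n)
                {
                    const std::size_t g   = g0 * G1 + g1; // same folding as idx[0] * G1 + idx[1]
                    const std::size_t src = ((g0 * G1 + g1) * O + o) * N + n;
                    const std::size_t dst = (g * N + n) * O + o; // dims 2 and 3 swapped
                    v_g_n_o[dst] = v_gs_os_ns[src];
                }
}
```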
...@@ -574,11 +551,7 @@ int run(int argc, char* argv[])
 k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
 z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
 v_device_buf.ToDevice(v_gs_os_ns.mData.data());
-y_device_buf.ToDevice(y_gs_ms_os.mData.data());
-lse_device_buf.ToDevice(lse_gs_ms.mData.data());
 ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
-kgrad_device_buf.SetZero();
-vgrad_device_buf.SetZero();
 auto gemm = DeviceGemmInstance{};
 auto invoker = gemm.MakeInvoker();
...@@ -688,13 +661,15 @@ int run(int argc, char* argv[])
           << gemm.GetTypeString() << std::endl;
 // copy z matirx data form device
-z_device_buf.FromDevice(z_g_m_n.mData.data());
+z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
+z_gs_ms_ns.ForEach(
+    [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
 // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
 bool pass = true;
 if(do_verification)
 {
-    // run fowad again for y, cause z_g_m_n update
+    // run fwd again for y, cause z_g_m_n update
     run_attention_fwd_host(q_g_m_k,
                            k_g_n_k,
                            v_g_n_o,
...@@ -710,7 +685,10 @@ int run(int argc, char* argv[])
     y_gs_ms_os.ForEach([&](auto& self, auto idx) {
         self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
     });
+    lse_gs_ms.ForEach(
+        [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
     y_device_buf.ToDevice(y_gs_ms_os.mData.data());
+    lse_device_buf.ToDevice(lse_gs_ms.mData.data());
     // call kernel again
     kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
...@@ -751,7 +729,7 @@ int run(int argc, char* argv[])
 #if PRINT_HOST
     {
         std::cout << "===== dP = dY * V^T\n";
-        std::cout << "ygrad_drop_g_m_o ref:\n" << ygrad_drop_g_m_n;
+        std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
         std::cout << "v_g_o_n ref:\n" << v_g_o_n;
         std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
     }
...@@ -763,7 +741,7 @@ int run(int argc, char* argv[])
         z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
     sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
         float ygrad_dot_y = 0;
         for(int o = 0; o < O; o++)
......
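
The comment in the hunk above quotes the softmax-backward identity dS_ij = P_ij * (dP_ij - dY_i . Y_i), which the example evaluates with a ForEach plus an inner reduction over O. Below is a self-contained host sketch of that same step for a single batch slice; the flat row-major buffers, the function name, and the loop bounds are assumptions for illustration, not the example's API.

```cpp
// Hedged reference sketch of the softmax-backward step quoted above:
// dS_ij = P_ij * (dP_ij - dY_i . Y_i), for one batch slice.
#include <cstddef>
#include <vector>

void softmax_backward(const std::vector<float>& p_g_m_n,      // P after softmax, [M][N]
                      const std::vector<float>& pgrad_g_m_n,  // dP, [M][N]
                      const std::vector<float>& y_g_m_o,      // Y, [M][O]
                      const std::vector<float>& ygrad_g_m_o,  // dY, [M][O]
                      std::vector<float>& sgrad_g_m_n,        // dS, [M][N]
                      std::size_t M, std::size_t N, std::size_t O)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        // dY_i . Y_i reduces over the head dimension O, matching the inner
        // "for(int o = 0; o < O; o++)" loop in the example.
        float ygrad_dot_y = 0.f;
        for(std::size_t o = 0; o < O; ++o)
            ygrad_dot_y += ygrad_g_m_o[m * O + o] * y_g_m_o[m * O + o];

        for(std::size_t n = 0; n < N; ++n)
            sgrad_g_m_n[m * N + n] =
                p_g_m_n[m * N + n] * (pgrad_g_m_n[m * N + n] - ygrad_dot_y);
    }
}
```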
...@@ -349,7 +349,7 @@ int run(int argc, char* argv[])
 float alpha = 1.f / std::sqrt(K);
-bool input_permute = false;
+bool input_permute = true; // false;
 bool output_permute = false;
 float p_drop = 0.2;
...@@ -448,7 +448,7 @@ int run(int argc, char* argv[])
 std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
 std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
-std::cout << "z_gs_ms_ks: " << z_gs_ms_ns.mDesc << std::endl;
+std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
 std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
 std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
 std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
...@@ -531,30 +531,8 @@ int run(int argc, char* argv[])
     [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
 k_gs_ns_ks.ForEach(
     [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-z_gs_ms_ns.ForEach(
-    [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
 v_gs_os_ns.ForEach(
     [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
-lse_gs_ms.ForEach(
-    [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
-run_attention_fwd_host(q_g_m_k,
-                       k_g_n_k,
-                       v_g_n_o,
-                       alpha,
-                       s_g_m_n,
-                       p_g_m_n,
-                       y_g_m_o,
-                       lse_g_m,
-                       p_drop_g_m_n,
-                       z_g_m_n,
-                       p_dropout_in_16bits,
-                       rp_dropout);
-y_gs_ms_os.ForEach(
-    [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
-lse_gs_ms.ForEach(
-    [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
 // qkv gradients have the same descriptor as with qkv
 DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -572,11 +550,7 @@ int run(int argc, char* argv[])
 k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
 z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
 v_device_buf.ToDevice(v_gs_os_ns.mData.data());
-y_device_buf.ToDevice(y_gs_ms_os.mData.data());
-lse_device_buf.ToDevice(lse_gs_ms.mData.data());
 ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
-kgrad_device_buf.SetZero();
-vgrad_device_buf.SetZero();
 auto gemm = DeviceGemmInstance{};
 auto invoker = gemm.MakeInvoker();
...@@ -686,7 +660,9 @@ int run(int argc, char* argv[])
           << gemm.GetTypeString() << std::endl;
 // copy z matirx data form device
-z_device_buf.FromDevice(z_g_m_n.mData.data());
+z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
+z_gs_ms_ns.ForEach(
+    [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
 // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
 bool pass = true;
...@@ -708,7 +684,10 @@ int run(int argc, char* argv[])
     y_gs_ms_os.ForEach([&](auto& self, auto idx) {
         self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
     });
+    lse_gs_ms.ForEach(
+        [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
     y_device_buf.ToDevice(y_gs_ms_os.mData.data());
+    lse_device_buf.ToDevice(lse_gs_ms.mData.data());
     // call kernel again
     kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
......
...@@ -1265,7 +1265,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
 {
     const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
     const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-    const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    const ushort p_dropout_in_16bits =
+        __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
     const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                  rp_dropout);
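
Context for the p_dropout_in_16bits change above: the keep probability 1 - p_drop is scaled once to a 16-bit threshold so each 16-bit random draw can be compared against it, and __builtin_amdgcn_readfirstlane broadcasts that scalar from lane 0 so the whole wavefront holds the same value. Below is a host-side sketch of the thresholding idea only; the keep-if-below convention and the names are assumptions for illustration, not the kernel's exact dropout semantics.

```cpp
// Hedged sketch of the 16-bit dropout threshold idea: scale the keep
// probability to [0, 65535] once, keep a value iff its 16-bit random draw
// falls at or below that threshold (one plausible convention), and rescale
// survivors by 1 / p_dropout.
#include <cmath>
#include <cstdint>

inline float apply_dropout_u16(float x, std::uint16_t rnd16, float p_drop)
{
    const float p_dropout = 1.0f - p_drop; // keep probability
    const std::uint16_t threshold =
        static_cast<std::uint16_t>(std::floor(p_dropout * 65535.0)); // host-side analogue
    const float rp_dropout = 1.0f / p_dropout; // rescale kept values
    return (rnd16 <= threshold) ? x * rp_dropout : 0.0f;
}
```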
...@@ -1520,7 +1521,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                    I1, // NBlockID
                                                    m0, // MRepeat
-                                                   n0, // NRepeat
+                                                   I1, // NRepeat
                                                    m1, // MWaveId
                                                    n1, // NWaveId
                                                    m2, // MPerXdl
...@@ -1556,7 +1557,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     Sequence<I1, // MBlockId
              I1, // NBlockID
              m0, // MRepeat
-             n0, // NRepeat
+             I1, // NRepeat
              m1, // MWaveId
              n1, // NWaveId
              m2, // MPerXdl
...@@ -1958,19 +1959,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     if(p_z_grid)
     {
         // P_dropped
-        blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
-                                                decltype(z_tenor_buffer),
-                                                true>(
-            s_slash_p_thread_buf, ph, z_tenor_buffer);
-        z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                         z_tenor_buffer,
-                                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                         z_grid_buf);
+        static_for<0, n0, 1>{}([&](auto i) {
+            blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
+                                                    decltype(z_tenor_buffer),
+                                                    true,
+                                                    decltype(n0),
+                                                    decltype(i)>(
+                s_slash_p_thread_buf, ph, z_tenor_buffer);
+            z_thread_copy_vgpr_to_global.Run(
+                z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                z_tenor_buffer,
+                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                z_grid_buf);
+            z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
+        });
+        z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+            make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
     }
     else
     {
+        ignore = z_grid_buf;
         // P_dropped
         blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
             s_slash_p_thread_buf, ph);
......
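
The new static_for loop in the hunk above applies dropout and writes the z (random mask) tile one NRepeat slice at a time, advancing the destination slice window by one step per iteration and rewinding by n0 at the end, which matches the thread descriptor change from n0 to I1 in the NRepeat slot. Below is a schematic of that control flow only; it is plain C++ with hypothetical callbacks, not CK's actual API.

```cpp
// Generic sketch of the per-repeat store pattern (not CK code): process one
// slice at a time, write it out, advance the destination window by one slice,
// then rewind by n0 so the copier is back at its starting offset.
#include <cstddef>

template <typename ApplySlice, typename WriteSlice, typename MoveWindow>
void per_repeat_store(std::size_t n0, ApplySlice apply, WriteSlice write, MoveWindow move)
{
    for(std::size_t i = 0; i < n0; ++i)
    {
        apply(i); // e.g. apply dropout to slice i of the thread buffer
        write(i); // e.g. copy the slice from registers to the global z buffer
        move(+1); // advance the destination window one NRepeat step
    }
    move(-static_cast<std::ptrdiff_t>(n0)); // rewind to the original window position
}
```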
...@@ -214,5 +214,48 @@ check_err(const Range& out,
     return res;
 }
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, unsigned short>,
+    bool>::type
+check_err(const Range& out, const RefRange& ref, unsigned short atol = 1)
+{
+    const std::string& msg = "Error: Incorrect U16 results!";
+    if(out.size() != ref.size())
+    {
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+                  << std::endl;
+        return false;
+    }
+    bool res{true};
+    int err_count = 0;
+    unsigned short err = 0;
+    unsigned short max_err = std::numeric_limits<unsigned short>::min();
+    for(std::size_t i = 0; i < ref.size(); ++i)
+    {
+        const unsigned short o = *std::next(std::begin(out), i);
+        const unsigned short r = *std::next(std::begin(ref), i);
+        err = (o > r) ? o - r : r - o;
+        if(err > atol)
+        {
+            max_err = err > max_err ? err : max_err;
+            err_count++;
+            if(err_count < 5)
+            {
+                std::cerr << msg << std::setw(12) << " out[" << i << "] != ref[" << i << "]: " << o
+                          << " != " << r << std::endl;
+            }
+            res = false;
+        }
+    }
+    if(!res)
+    {
+        std::cerr << std::setw(12) << "max err: " << max_err << std::endl;
+    }
+    return res;
+}
 } // namespace utils
 } // namespace ck
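
Below is a hypothetical caller of the new unsigned-short check_err overload added above, e.g. for comparing a device-produced dropout mask z against its host reference. The function name, vectors, and tolerance are illustrative only, assuming this overload is the one selected for ranges of unsigned short as its enable_if suggests.

```cpp
// Hypothetical usage sketch of the new u16 check_err overload (illustrative names).
#include <vector>

bool verify_z(const std::vector<unsigned short>& z_device,
              const std::vector<unsigned short>& z_host)
{
    // Element-wise comparison with an absolute tolerance of 1; prints the first
    // few mismatches plus the maximum error and returns false on failure.
    return ck::utils::check_err(z_device, z_host, /*atol=*/1);
}
```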