Commit 4e79cc4b authored by ltqin

save z matrix

parent cb914a54
@@ -32,7 +32,7 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 using BF16 = ck::bhalf_t;
 using F32 = float;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
@@ -479,7 +479,7 @@ int run(int argc, char* argv[])
         {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
         QKVElementOp{},
         QKVElementOp{},
         Scale{alpha},
         QKVElementOp{},
         YElementOp{});
...
@@ -24,12 +24,13 @@ Kernel outputs:
 */
 #define PRINT_HOST 0
-#define USING_MASK 1
+#define USING_MASK 0
 #include <iostream>
 #include <numeric>
 #include <initializer_list>
 #include <cstdlib>
+#include <fstream>
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
@@ -259,12 +260,12 @@ int run(int argc, char* argv[])
     bool input_permute = false;
     bool output_permute = false;
     float p_drop = 0.2;
     float p_dropout = 1 - p_drop;
     float rp_dropout = 1.0 / p_dropout;
     const unsigned long long seed = 1;
     const unsigned long long offset = 0;
     float scale_rp_dropout = alpha * rp_dropout;
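The bookkeeping in this hunk follows the usual inverted-dropout convention: with keep probability p_dropout = 1 - p_drop, surviving elements are rescaled by rp_dropout = 1 / p_dropout so their expected value is unchanged, and the softmax scale alpha is folded in as scale_rp_dropout. A minimal standalone sketch of the same arithmetic (the alpha value here is hypothetical):

```cpp
#include <cstdio>

int main()
{
    const float alpha      = 0.125f;           // hypothetical softmax scale
    const float p_drop     = 0.2f;             // drop probability, as above
    const float p_dropout  = 1.0f - p_drop;    // keep probability = 0.8
    const float rp_dropout = 1.0f / p_dropout; // rescale factor = 1.25
    const float scale_rp_dropout = alpha * rp_dropout;

    // E[mask * rp_dropout * x] = p_dropout * rp_dropout * x = x,
    // so rescaling keeps the expectation of kept activations unchanged.
    std::printf("p_dropout = %f, rp_dropout = %f, scale_rp_dropout = %f\n",
                p_dropout, rp_dropout, scale_rp_dropout);
    return 0;
}
```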
@@ -333,7 +334,6 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
             : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
-    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
     std::vector<ck::index_t> z_gs_ms_ns_strides =
         input_permute
@@ -475,7 +475,7 @@ int run(int argc, char* argv[])
     ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
     kgrad_device_buf.SetZero();
     vgrad_device_buf.SetZero();
-    //z_device_buf.SetZero();
+    // z_device_buf.SetZero();
     auto gemm = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -509,11 +509,11 @@ int run(int argc, char* argv[])
         {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
         QKVElementOp{},
         QKVElementOp{},
-        Scale{scale_rp_dropout}, //dQ *= scale_rp_dropout
+        Scale{scale_rp_dropout}, // dQ *= scale_rp_dropout
         QKVElementOp{},
         YElementOp{},
         p_drop,
-        std::tuple<unsigned long long,unsigned long long>(seed,offset));
+        std::tuple<unsigned long long, unsigned long long>(seed, offset));
     if(!gemm.IsSupportedArgument(argument))
     {
@@ -543,13 +543,14 @@ int run(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;
+    // copy z matrix data from device
+    std::ofstream file("./z_matrix_txt");
+    z_device_buf.FromDevice(z_g_m_n.mData.data());
+    file << z_g_m_n << std::endl;
+    // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
     bool pass = true;
     if(do_verification)
     {
-        // copy z matrix data from device
-        z_device_buf.FromDevice(z_g_m_n.mData.data());
-        // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
         kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
         vgrad_device_buf.SetZero();
         invoker.Run(argument, StreamConfig{nullptr, false});
...
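These added lines are what give the commit its name: after the timed run, the z (dropout) matrix is copied back to the host and streamed to ./z_matrix_txt unconditionally, rather than only under do_verification. A minimal sketch of the same dump pattern, with a plain std::vector standing in for the CK host Tensor z_g_m_n (whose operator<< does the formatting in the example):

```cpp
#include <fstream>
#include <vector>

int main()
{
    // Stand-in for z_g_m_n.mData after z_device_buf.FromDevice(...),
    // i.e. after the device-to-host copy of the z matrix.
    std::vector<unsigned short> z_host(16, 0);

    std::ofstream file("./z_matrix_txt"); // same path as in the example
    for(const auto v : z_host)
        file << v << ' ';
    file << '\n';
    return 0;
}
```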
@@ -96,6 +96,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     static constexpr auto I5 = Number<5>{};
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};
+    static constexpr auto I8 = Number<8>{};
+    static constexpr auto I9 = Number<9>{};
     static constexpr auto WaveSize = 64;
     // K1 should be Number<...>
@@ -1483,7 +1485,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const auto wave_id = GetGemm0WaveIdx();
         const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
-        /*if(get_thread_global_1d_id() == 191)
+        if(get_thread_global_1d_id() == 191)
         {
             printf("wave_id{ %d, %d, %d}, wave_m_n_id{%d, %d}\n",
                    wave_id[I0],
@@ -1491,7 +1493,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                    wave_id[I2],
                    wave_m_n_id[I0],
                    wave_m_n_id[I1]);
-        }*/
+            printf("z grid descriptor{%d, %d, %d, %d, %d, %d, %d, %d, %d, %d}\n",
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I0),
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I1),
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I2),
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I3),
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I4),
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I5),
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I6),
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I7),
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I8),
+                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I9));
+        }
         auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
             ushort,
             ushort,
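Uncommenting this block enables single-thread debug output: gating printf on one global thread id keeps the kernel from emitting one line per thread, and the ten GetLength calls dump the shape of the tiled z grid descriptor (hence the new I8/I9 constants above). A sketch of the gating pattern in a bare HIP kernel, assuming get_thread_global_1d_id() corresponds to the flat index computed below:

```cpp
#include <hip/hip_runtime.h>

__global__ void debug_print_kernel()
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if(tid == 191) // print from exactly one thread to avoid flooding stdout
        printf("hello from global thread %d\n", tid);
}

int main()
{
    hipLaunchKernelGGL(debug_print_kernel, dim3(2), dim3(256), 0, 0);
    hipDeviceSynchronize(); // flush device printf before exit
    return 0;
}
```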
@@ -1767,8 +1780,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
         constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
         const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
-        const float scale = 1.0f / std::sqrt(K);
+        const float scalar = 1.0f / std::sqrt(K);
         // Initialize dQ
         qgrad_thread_buf.Clear();
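The rename from scale to scalar looks like a disambiguation from the other scaling factors in scope; the value itself is the standard attention scaling 1/sqrt(d_k), with the head dimension K reassembled from the K0/K1 split of the descriptor (K = K0 * K1). A small sketch with hypothetical split sizes:

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    // Hypothetical K0/K1 split; the kernel reads these as
    // k_grid_desc_k0_n_k1.GetLength(I0) and GetLength(I2).
    const int K0 = 16;
    const int K1 = 8;
    const int K  = K0 * K1; // reassembled head dimension, 128

    const float scalar = 1.0f / std::sqrt(static_cast<float>(K));
    std::printf("K = %d, scalar = %f\n", K, scalar); // scalar ~= 0.088388
    return 0;
}
```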
@@ -1849,14 +1862,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 }
                 else
                 {
-                    s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i];
+                    s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
                 }
             });
         }
         else
         {
             static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
-                [&](auto i) { s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i]; });
+                [&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
         }
         block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
...