Commit 4e79cc4b authored by ltqin's avatar ltqin
Browse files

save z matrix

parent cb914a54
......@@ -32,7 +32,7 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F32 = float;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
......
......@@ -479,7 +479,7 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
Scale{alpha},
QKVElementOp{},
YElementOp{});
......
......@@ -24,12 +24,13 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK 1
#define USING_MASK 0
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
......@@ -259,12 +260,12 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
float scale_rp_dropout = alpha * rp_dropout;
......@@ -333,7 +334,6 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
......@@ -475,7 +475,7 @@ int run(int argc, char* argv[])
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero();
//z_device_buf.SetZero();
// z_device_buf.SetZero();
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
......@@ -509,11 +509,11 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{scale_rp_dropout}, //dQ *= scale_rp_dropout
Scale{scale_rp_dropout}, // dQ *= scale_rp_dropout
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long,unsigned long long>(seed,offset));
std::tuple<unsigned long long, unsigned long long>(seed, offset));
if(!gemm.IsSupportedArgument(argument))
{
......@@ -543,13 +543,14 @@ int run(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
// copy z matrix data from device
std::ofstream file("./z_matrix_txt");
z_device_buf.FromDevice(z_g_m_n.mData.data());
file << z_g_m_n << std::endl;
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true;
if(do_verification)
{
// copy z matrix data from device
z_device_buf.FromDevice(z_g_m_n.mData.data());
//std::cout << "z_g_m_n ref:\n" << z_g_m_n;
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
vgrad_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
......
......@@ -96,6 +96,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
......@@ -1483,7 +1485,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
/*if(get_thread_global_1d_id() == 191)
if(get_thread_global_1d_id() == 191)
{
printf("wave_id{ %d, %d, %d}, wave_m_n_id{%d, %d}\n",
wave_id[I0],
......@@ -1491,7 +1493,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
wave_id[I2],
wave_m_n_id[I0],
wave_m_n_id[I1]);
}*/
printf("z grid descripter{%d, %d, %d, %d, %d, %d, %d, %d, %d, %d}\n",
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I0),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I1),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I2),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I3),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I4),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I5),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I6),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I7),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I8),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I9));
}
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
......@@ -1767,8 +1780,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
const float scale = 1.0f / std::sqrt(K);
const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
const float scalar = 1.0f / std::sqrt(K);
// Initialize dQ
qgrad_thread_buf.Clear();
......@@ -1849,14 +1862,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
else
{
s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i];
s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
}
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
[&](auto i) { s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i]; });
[&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment