Commit cb914a54 authored by ltqin

move z matrix pos

parent 1306ae6b
@@ -248,12 +248,12 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 512;
-    ck::index_t N  = 512;
+    ck::index_t M  = 128;
+    ck::index_t N  = 256;
     ck::index_t K  = 128;
     ck::index_t O  = 128;
-    ck::index_t G0 = 3;
-    ck::index_t G1 = 2;
+    ck::index_t G0 = 1;
+    ck::index_t G1 = 1;
     float alpha = 1.f / std::sqrt(K);
@@ -363,7 +363,7 @@ int run(int argc, char* argv[])
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
     std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
-    z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{-1});
+    z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
     switch(init_method)
     {
     case 0: break;
@@ -475,6 +475,7 @@ int run(int argc, char* argv[])
     ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
     kgrad_device_buf.SetZero();
     vgrad_device_buf.SetZero();
+    // z_device_buf.SetZero();
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -545,6 +546,10 @@ int run(int argc, char* argv[])
     bool pass = true;
     if(do_verification)
     {
+        // copy z matrix data from device
+        z_device_buf.FromDevice(z_g_m_n.mData.data());
+        // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
         kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
         vgrad_device_buf.SetZero();
         invoker.Run(argument, StreamConfig{nullptr, false});
...
@@ -1441,7 +1441,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         // z vgpr copy to global
         //
         // z matrix threadwise desc
-        if(get_thread_global_1d_id() == 0)
+        /*if(get_thread_global_1d_id() == 0)
         {
             printf("m0: %d n0: %d m1: %d n1: %d m2: %d n2: %d n3: %d n4: %d \n",
                    m0.value, // MRepeat
@@ -1452,7 +1452,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                    n2.value, // NGroupNum
                    n3.value, // NInputNum
                    n4.value);
-        }
+        }*/
         constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
             make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                            I1, // NBlockID
@@ -1470,6 +1470,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
             true>
             z_tenor_buffer;
+        z_tenor_buffer.Clear();
         // z matrix global desc
         /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
         const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
@@ -1482,7 +1483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const auto wave_id     = GetGemm0WaveIdx();
         const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
-        if(get_thread_global_1d_id() == 191)
+        /*if(get_thread_global_1d_id() == 191)
         {
             printf("wave_id{ %d, %d, %d}, wave_m_n_id{%d, %d}\n",
                    wave_id[I0],
@@ -1490,7 +1491,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                    wave_id[I2],
                    wave_m_n_id[I0],
                    wave_m_n_id[I1]);
-        }
+        }*/
         auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
             ushort,
             ushort,
@@ -2116,6 +2117,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                 kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
+            z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+
+            if(get_thread_global_1d_id() == 1)
+                printf("gemm1_k_block_outer_index: %d num_gemm1_k_block_outer_loop: %d\n",
+                       gemm1_k_block_outer_index,
+                       num_gemm1_k_block_outer_loop);
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
         // shuffle dQ and write
...
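The new MoveDstSliceWindow call in the last hunk is what the commit title "move z matrix pos" refers to: after each outer Gemm1 K-block iteration, the destination window for the per-thread z tile is advanced by one block along the N0 (column-block) dimension, so iteration j writes its tile into column-block j of the global z matrix. Below is a minimal sketch of that window-advance idea; the struct and member names are illustrative assumptions, not the CK API.

#include <cstdio>

// Stand-in for the z destination slice window (illustrative, not CK's type).
struct ZDstWindow
{
    int n0 = 0;                            // current column-block origin
    void MoveDstSliceWindow(int n0_step) { n0 += n0_step; }
};

int main()
{
    const int num_outer_loops = 4;         // stands in for num_gemm1_k_block_outer_loop
    ZDstWindow window{};
    for(int j = 0; j < num_outer_loops; ++j)
    {
        std::printf("iteration %d writes z tile at column-block %d\n", j, window.n0);
        window.MoveDstSliceWindow(1);      // mirrors make_multi_index(0, 1, 0, ..., 0): step N0 by one
    }
    return 0;
}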