Commit d37c1d0b authored by guangzlu
Browse files

dim=32 pass now

parent 5d90769e
...@@ -32,7 +32,7 @@ Kernel outputs: ...@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8. #define DIM 32 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -730,7 +730,7 @@ int run(int argc, char* argv[]) ...@@ -730,7 +730,7 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
float p_drop = 0.0; float p_drop = 0.1;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -1040,7 +1040,8 @@ int run(int argc, char* argv[]) ...@@ -1040,7 +1040,8 @@ int run(int argc, char* argv[])
YElementOp{}, YElementOp{},
p_drop, p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun qgrad_device_buf.SetZero(); // reset global accum buffer and rerun
kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero(); vgrad_device_buf.SetZero();
float ave_time_bwd = invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, true}); float ave_time_bwd = invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, true});
...@@ -1149,6 +1150,7 @@ int run(int argc, char* argv[]) ...@@ -1149,6 +1150,7 @@ int run(int argc, char* argv[])
std::ofstream fwd_file("./z_fwd_matrix_txt"); std::ofstream fwd_file("./z_fwd_matrix_txt");
fwd_file << z_fwd_gs_ms_ns << std::endl; fwd_file << z_fwd_gs_ms_ns << std::endl;
qgrad_device_buf.SetZero();
kgrad_device_buf.SetZero(); kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero(); vgrad_device_buf.SetZero();
......
...@@ -1525,7 +1525,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1525,7 +1525,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
I1, // NRepeat n0, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1561,7 +1561,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1561,7 +1561,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<I1, // MBlockId Sequence<I1, // MBlockId
I1, // NBlockId I1, // NBlockId
m0, // MRepeat m0, // MRepeat
I1, // NRepeat n0, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1984,33 +1984,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1984,33 +1984,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto global_elem_id = auto global_elem_id =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value; // if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// if(get_thread_global_1d_id() == 0){
// printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
// Acc0TileIterator::GetNumOfAccess()); printf("n0.value is %d \n", n0.value);
// printf("id_step is %d \n", id_step);
//}
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer), decltype(z_tenor_buffer),
true, true>(
decltype(n0), s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer);
decltype(i)>(s_slash_p_thread_buf,
ph, z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
global_elem_id +
id_step * i.value,
z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf); z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, // P_dropped
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0)); // static_for<0, n0, 1>{}([&](auto i) {
}); // blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( // decltype(z_tenor_buffer),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, // true,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0)); // decltype(n0),
// decltype(i)>(s_slash_p_thread_buf,
// ph,
// global_elem_id + id_step
// * i.value,
// z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
} }
else else
{ {
......
...@@ -860,7 +860,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -860,7 +860,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
I1, // NRepeat n0, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -891,7 +891,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -891,7 +891,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence<I1, // MBlockId Sequence<I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
I1, // NRepeat n0, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1067,18 +1067,15 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1067,18 +1067,15 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// printf("at 1 global_elem_id is %d \n", global_elem_id); // printf("at 1 global_elem_id is %d \n", global_elem_id);
// } // }
index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value; // index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global // save z to global
if(p_z_grid) if(p_z_grid)
{ {
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer), decltype(z_tenor_buffer),
false, false>(
decltype(n0), acc_thread_buf, ph, global_elem_id, z_tenor_buffer);
decltype(i)>(
acc_thread_buf, ph, global_elem_id + i.value * id_step, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
...@@ -1086,13 +1083,28 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1086,13 +1083,28 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_tenor_buffer, z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf); z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( // static_for<0, n0, 1>{}([&](auto i) {
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, // blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0)); // decltype(z_tenor_buffer),
}); // false,
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( // decltype(n0),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, // decltype(i)>(
make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0)); // acc_thread_buf, ph, global_elem_id + id_step * i.value,
// z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment