Commit 88980945 authored by guangzlu's avatar guangzlu
Browse files

updated new dropout for attn fwd

parent db8018de
...@@ -103,17 +103,17 @@ using DeviceGemmInstance = ...@@ -103,17 +103,17 @@ using DeviceGemmInstance =
TensorSpecC, TensorSpecC,
1, 1,
256, 256,
256, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 64, // KPerBlock
64, // Gemm1NPerBlock 64, // Gemm1NPerBlock
32, // Gemm1KPerBlock 64, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
2, // B1K1 2, // B1K1
32, // MPerXDL 32, // MPerXDL
32, // NPerXDL 32, // NPerXDL
2, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
...@@ -130,11 +130,11 @@ using DeviceGemmInstance = ...@@ -130,11 +130,11 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
S<16, 16, 1>, // B1BlockTransfer S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
4, 2,
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
......
...@@ -99,17 +99,17 @@ using DeviceGemmInstance = ...@@ -99,17 +99,17 @@ using DeviceGemmInstance =
TensorSpecC, TensorSpecC,
1, 1,
256, 256,
256, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 64, // KPerBlock
64, // Gemm1NPerBlock 64, // Gemm1NPerBlock
32, // Gemm1KPerBlock 64, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
2, // B1K1 2, // B1K1
32, // MPerXDL 32, // MPerXDL
32, // NPerXDL 32, // NPerXDL
2, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
...@@ -126,11 +126,11 @@ using DeviceGemmInstance = ...@@ -126,11 +126,11 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
S<16, 16, 1>, // B1BlockTransfer S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
4, 2,
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
......
...@@ -105,8 +105,8 @@ using DeviceGemmInstance = ...@@ -105,8 +105,8 @@ using DeviceGemmInstance =
256, 256,
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 64, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
64, // Gemm1KPerBlock 64, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -115,7 +115,7 @@ using DeviceGemmInstance = ...@@ -115,7 +115,7 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -130,11 +130,11 @@ using DeviceGemmInstance = ...@@ -130,11 +130,11 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
S<16, 16, 1>, // B1BlockTransfer S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
4, 2,
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
......
...@@ -101,8 +101,8 @@ using DeviceGemmInstance = ...@@ -101,8 +101,8 @@ using DeviceGemmInstance =
256, 256,
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 64, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
64, // Gemm1KPerBlock 64, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -111,7 +111,7 @@ using DeviceGemmInstance = ...@@ -111,7 +111,7 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -126,11 +126,11 @@ using DeviceGemmInstance = ...@@ -126,11 +126,11 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
S<16, 16, 1>, // B1BlockTransfer S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
4, 2,
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
......
...@@ -9,8 +9,8 @@ int run(int argc, char* argv[]) ...@@ -9,8 +9,8 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C // GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 1000; // 120 ck::index_t M = 512; // 120
ck::index_t N = 1000; // 1000 ck::index_t N = 512; // 1000
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t O = 64; ck::index_t O = 64;
...@@ -97,7 +97,7 @@ int run(int argc, char* argv[]) ...@@ -97,7 +97,7 @@ int run(int argc, char* argv[])
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
output_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
......
...@@ -10,8 +10,7 @@ int run(int argc, char* argv[]) ...@@ -10,8 +10,7 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
float p_drop = 0.1;
float p_drop = 0.2;
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
...@@ -84,8 +83,8 @@ int run(int argc, char* argv[]) ...@@ -84,8 +83,8 @@ int run(int argc, char* argv[])
int M = 128 * (rand() % 8 + 1); int M = 128 * (rand() % 8 + 1);
int N = 128 * (rand() % 8 + 1); int N = 128 * (rand() % 8 + 1);
int K = 128; int K = 64;
int O = 128; int O = 64;
int G0 = rand() % 3 + 1; int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1; int G1 = rand() % 5 + 1;
...@@ -117,7 +116,7 @@ int run(int argc, char* argv[]) ...@@ -117,7 +116,7 @@ int run(int argc, char* argv[])
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
output_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
......
...@@ -274,11 +274,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -274,11 +274,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
// if(Gemm1N != K) if(Gemm1N != K)
//{ {
// std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n'; std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
// return false; return false;
//} }
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{ {
...@@ -852,7 +852,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -852,7 +852,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
n0, // NRepeat I1, // n0, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -883,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -883,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence<I1, // MBlockId Sequence<I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
n0, // NRepeat I1, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1006,10 +1006,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1006,10 +1006,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// save z to global // save z to global
if(p_z_grid) if(p_z_grid)
{ {
// P_dropped static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer), decltype(z_tenor_buffer),
false>( false,
decltype(n0),
decltype(i)>(
acc_thread_buf, ph, z_tenor_buffer); acc_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
...@@ -1018,13 +1020,20 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1018,13 +1020,20 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_tenor_buffer, z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf); z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} }
else else
{ {
// ignore = z_grid_buf;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>( blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>(
acc_thread_buf, ph); acc_thread_buf, ph);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment