Commit 88980945 authored by guangzlu's avatar guangzlu
Browse files

updated new dropout for attn fwd

parent db8018de
......@@ -103,17 +103,17 @@ using DeviceGemmInstance =
TensorSpecC,
1,
256,
256, // MPerBlock
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
......@@ -130,11 +130,11 @@ using DeviceGemmInstance =
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
......
......@@ -99,17 +99,17 @@ using DeviceGemmInstance =
TensorSpecC,
1,
256,
256, // MPerBlock
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
......@@ -126,11 +126,11 @@ using DeviceGemmInstance =
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
......
......@@ -105,8 +105,8 @@ using DeviceGemmInstance =
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
......@@ -115,7 +115,7 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -130,11 +130,11 @@ using DeviceGemmInstance =
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
......
......@@ -101,8 +101,8 @@ using DeviceGemmInstance =
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
......@@ -111,7 +111,7 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -126,11 +126,11 @@ using DeviceGemmInstance =
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
......
......@@ -9,8 +9,8 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 1000; // 120
ck::index_t N = 1000; // 1000
ck::index_t M = 512; // 120
ck::index_t N = 512; // 1000
ck::index_t K = 64;
ck::index_t O = 64;
......@@ -97,7 +97,7 @@ int run(int argc, char* argv[])
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
output_permute
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
......
......@@ -10,8 +10,7 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = true;
float p_drop = 0.2;
float p_drop = 0.1;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
......@@ -84,8 +83,8 @@ int run(int argc, char* argv[])
int M = 128 * (rand() % 8 + 1);
int N = 128 * (rand() % 8 + 1);
int K = 128;
int O = 128;
int K = 64;
int O = 64;
int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1;
......@@ -117,7 +116,7 @@ int run(int argc, char* argv[])
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
output_permute
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
......
......@@ -274,11 +274,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
// if(Gemm1N != K)
//{
// std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
// return false;
//}
if(Gemm1N != K)
{
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
......@@ -852,7 +852,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
I1, // n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -883,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -1006,25 +1006,34 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// save z to global
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
false>(
acc_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
false,
decltype(n0),
decltype(i)>(
acc_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
else
{
// ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>(
acc_thread_buf, ph);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment