Commit 7af1b43a authored by ltqin
Browse files

one block dv pass

parent b281dac5
......@@ -12,7 +12,7 @@ add_example_executable(example_batched_multihead_attention_backward batched_mult
add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
add_example_executable(example_batched_multihead_attention_backward_V1R2 batched_multihead_attention_backward_V2R2.cpp)
add_example_executable(example_batched_multihead_attention_backward_V2R2 batched_multihead_attention_backward_V2R2.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
......@@ -350,9 +350,9 @@ using DeviceGemmInstance =
2, // B1K1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
1, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
......@@ -375,8 +375,8 @@ using DeviceGemmInstance =
4,
2,
false,
4, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
1, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec, // MaskingSpecialization
......@@ -501,17 +501,17 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t M = 128;
ck::index_t N = 128;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1 = 6;
ck::index_t G0 = 1;
ck::index_t G1 = 1;
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.2;
float p_drop = 0;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......@@ -1040,6 +1040,7 @@ int run(int argc, char* argv[])
"error",
1e-2,
1e-2);
//std::cout << vgrad_gs_os_ns_device_result << std::endl;
}
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment