Commit 79823c87 authored by Qianfeng Zhang
Browse files

Update to dropout kernel/device-op/example

parent 064f596c
......@@ -29,7 +29,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/host_common_util.hpp"
//#include "ck/library/utility/host_common_util.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -316,17 +316,16 @@ using DeviceDropoutInstance =
TensorSpecB1,
TensorSpecC,
256, // BlockSize
128, // MPerBlock
64, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4>; // NXdlPerWave
2, // MXdlPerWave
1>; // NXdlPerWave
#include "run_batched_multihead_attention_bias_forward_zcheck.inc"
......
......@@ -3,14 +3,14 @@
int run(int argc, char* argv[])
{
using ck::host_common::dumpBufferToFile;
// using ck::host_common::dumpBufferToFile;
int init_method = 1;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 200; // 120
ck::index_t N = 200; // 1000
ck::index_t M = 1000; // 120
ck::index_t N = 1000; // 1000
ck::index_t K = DIM;
ck::index_t O = DIM;
......@@ -225,7 +225,7 @@ int run(int argc, char* argv[])
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
dumpBufferToFile("forward_z.dat", z_gs_ms_ns.mData.data(), z_gs_ms_ns.mData.size());
// dumpBufferToFile("forward_z.dat", z_gs_ms_ns.mData.data(), z_gs_ms_ns.mData.size());
// do Dropout
auto dropout_op = DeviceDropoutInstance();
......@@ -245,7 +245,7 @@ int run(int argc, char* argv[])
z_device_buf_2.FromDevice(z_gs_ms_ns_2.mData.data());
dumpBufferToFile("canonic_z.dat", z_gs_ms_ns_2.mData.data(), z_gs_ms_ns_2.mData.size());
// dumpBufferToFile("canonic_z.dat", z_gs_ms_ns_2.mData.data(), z_gs_ms_ns_2.mData.size());
return ck::utils::check_integer_err(z_gs_ms_ns.mData, z_gs_ms_ns_2.mData, 1.0e-5);
}
......@@ -105,7 +105,6 @@ template <index_t NumDimG,
index_t Gemm1NPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
......
......@@ -343,6 +343,10 @@ struct GridwiseBatchedDropout
auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
......@@ -352,9 +356,6 @@ struct GridwiseBatchedDropout
__builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
// save z to global
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment