Commit 79823c87 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Update to dropout kernel/device-op/example

parent 064f596c
...@@ -29,7 +29,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -29,7 +29,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/library/utility/host_common_util.hpp" //#include "ck/library/utility/host_common_util.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -316,17 +316,16 @@ using DeviceDropoutInstance = ...@@ -316,17 +316,16 @@ using DeviceDropoutInstance =
TensorSpecB1, TensorSpecB1,
TensorSpecC, TensorSpecC,
256, // BlockSize 256, // BlockSize
128, // MPerBlock 64, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 128, // Gemm1NPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
2, // B1K1
32, // MPerXDL 32, // MPerXDL
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 2, // MXdlPerWave
4>; // NXdlPerWave 1>; // NXdlPerWave
#include "run_batched_multihead_attention_bias_forward_zcheck.inc" #include "run_batched_multihead_attention_bias_forward_zcheck.inc"
......
...@@ -3,14 +3,14 @@ ...@@ -3,14 +3,14 @@
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
using ck::host_common::dumpBufferToFile; // using ck::host_common::dumpBufferToFile;
int init_method = 1; int init_method = 1;
// GEMM shape for A/B0/B1/C // GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 200; // 120 ck::index_t M = 1000; // 120
ck::index_t N = 200; // 1000 ck::index_t N = 1000; // 1000
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
...@@ -225,7 +225,7 @@ int run(int argc, char* argv[]) ...@@ -225,7 +225,7 @@ int run(int argc, char* argv[])
z_device_buf.FromDevice(z_gs_ms_ns.mData.data()); z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
dumpBufferToFile("forward_z.dat", z_gs_ms_ns.mData.data(), z_gs_ms_ns.mData.size()); // dumpBufferToFile("forward_z.dat", z_gs_ms_ns.mData.data(), z_gs_ms_ns.mData.size());
// do Dropout // do Dropout
auto dropout_op = DeviceDropoutInstance(); auto dropout_op = DeviceDropoutInstance();
...@@ -245,7 +245,7 @@ int run(int argc, char* argv[]) ...@@ -245,7 +245,7 @@ int run(int argc, char* argv[])
z_device_buf_2.FromDevice(z_gs_ms_ns_2.mData.data()); z_device_buf_2.FromDevice(z_gs_ms_ns_2.mData.data());
dumpBufferToFile("canonic_z.dat", z_gs_ms_ns_2.mData.data(), z_gs_ms_ns_2.mData.size()); // dumpBufferToFile("canonic_z.dat", z_gs_ms_ns_2.mData.data(), z_gs_ms_ns_2.mData.size());
return ck::utils::check_integer_err(z_gs_ms_ns.mData, z_gs_ms_ns_2.mData, 1.0e-5); return ck::utils::check_integer_err(z_gs_ms_ns.mData, z_gs_ms_ns_2.mData, 1.0e-5);
} }
...@@ -105,7 +105,6 @@ template <index_t NumDimG, ...@@ -105,7 +105,6 @@ template <index_t NumDimG,
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t AK1, index_t AK1,
index_t BK1, index_t BK1,
index_t B1K1,
index_t MPerXDL, index_t MPerXDL,
index_t NPerXDL, index_t NPerXDL,
index_t MXdlPerWave, index_t MXdlPerWave,
......
...@@ -343,6 +343,10 @@ struct GridwiseBatchedDropout ...@@ -343,6 +343,10 @@ struct GridwiseBatchedDropout
auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D( auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{}); Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
// gemm0 M loop // gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1; index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
...@@ -352,9 +356,6 @@ struct GridwiseBatchedDropout ...@@ -352,9 +356,6 @@ struct GridwiseBatchedDropout
__builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock); __builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
// save z to global // save z to global
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment