Commit f82a220f authored by guangzlu
Browse files

v4 pass

parent ff88ffa4
...@@ -766,7 +766,7 @@ int run(int argc, char* argv[]) ...@@ -766,7 +766,7 @@ int run(int argc, char* argv[])
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), // set to nullptr
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
......
...@@ -32,7 +32,7 @@ Kernel outputs: ...@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8. #define DIM 32 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -710,17 +710,17 @@ int run(int argc, char* argv[]) ...@@ -710,17 +710,17 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 1000; // 512 ck::index_t M = 500; // 512
ck::index_t N = 1000; // 512 ck::index_t N = 500; // 512
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 4; // 54 ck::index_t G0 = 2; // 54
ck::index_t G1 = 6; // 16 ck::index_t G1 = 1; // 16
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
float p_drop = 0.0; float p_drop = 0.1;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -944,7 +944,7 @@ int run(int argc, char* argv[]) ...@@ -944,7 +944,7 @@ int run(int argc, char* argv[])
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; {}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; {}, // std::array<void*, 1> p_acc1_biases;
...@@ -998,7 +998,7 @@ int run(int argc, char* argv[]) ...@@ -998,7 +998,7 @@ int run(int argc, char* argv[])
auto argument_bwd = gemm_bwd.MakeArgument( auto argument_bwd = gemm_bwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr static_cast<ZDataType*>(z_bwd_device_buf.GetDeviceBuffer()), // set to nullptr
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
...@@ -1399,20 +1399,20 @@ int run(int argc, char* argv[]) ...@@ -1399,20 +1399,20 @@ int run(int argc, char* argv[])
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData, pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData, qgrad_gs_ms_ks_host_result.mData,
"error", "error",
1e-2, 1e-3,
1e-2); 1e-3);
std::cout << "Checking kgrad:\n"; std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData, pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData, kgrad_gs_ns_ks_host_result.mData,
"error", "error",
1e-2, 1e-3,
1e-2); 1e-3);
std::cout << "Checking vgrad:\n"; std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData, pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData, vgrad_gs_os_ns_host_result.mData,
"error", "error",
1e-2, 1e-3,
1e-2); 1e-3);
} }
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1); return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
......
...@@ -145,14 +145,14 @@ struct BlockwiseDropout ...@@ -145,14 +145,14 @@ struct BlockwiseDropout
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
} }
ushort tmp_id[tmp_size]; // ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++) // for(int i = 0; i < philox_calls; i++)
{ //{
for(int j = 0; j < 4; j++) // for(int j = 0; j < 4; j++)
{ // {
tmp_id[i * 4 + j] = element_global_1d_id + i * 8; // tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
} // }
} //}
block_sync_lds(); block_sync_lds();
...@@ -162,7 +162,7 @@ struct BlockwiseDropout ...@@ -162,7 +162,7 @@ struct BlockwiseDropout
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset)); execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp_id[tmp_index]; z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
}); });
...@@ -208,17 +208,17 @@ struct BlockwiseDropout ...@@ -208,17 +208,17 @@ struct BlockwiseDropout
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
} }
ushort tmp_id[tmp_size]; // ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++) // for(int i = 0; i < philox_calls; i++)
{ //{
for(int j = 0; j < 4; j++) // for(int j = 0; j < 4; j++)
{ // {
tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw; // tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
} // }
} //}
block_sync_lds(); block_sync_lds();
...@@ -226,7 +226,7 @@ struct BlockwiseDropout ...@@ -226,7 +226,7 @@ struct BlockwiseDropout
static_for<0, MRepeat, 1>{}([&](auto iM) { static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) { static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp_id[tmp_index]; z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
}); });
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -40,6 +41,7 @@ template <typename GridwiseGemm, ...@@ -40,6 +41,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
...@@ -73,6 +75,8 @@ __global__ void ...@@ -73,6 +75,8 @@ __global__ void
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -138,6 +142,7 @@ __global__ void ...@@ -138,6 +142,7 @@ __global__ void
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
...@@ -173,6 +178,7 @@ __global__ void ...@@ -173,6 +178,7 @@ __global__ void
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
...@@ -828,6 +834,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -828,6 +834,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4(
z_grid_desc_m_n_);
// tmp z tensor for shuffle
// Tensor<ZDataType> z_tmp_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
// DeviceMem z_tmp_device_buf(sizeof(ZDataType) *
// z_tmp_gs_ms_ns.mDesc.GetElementSpaceSize());
// z_tmp_device_buf.ToDevice(z_tmp_gs_ms_ns.mData.data());
// p_z_tmp_grid_ = reinterpret_cast<ZDataType*>(z_tmp_device_buf.GetDeviceBuffer());
// Print(); // Print();
} }
...@@ -859,7 +877,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -859,7 +877,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// pointers // pointers
const InputDataType* p_a_grid_; const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_; const InputDataType* p_b_grid_;
// ZDataType* p_z_tmp_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_; const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
...@@ -890,6 +911,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -890,6 +911,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_;
// block-to-c-tile map // block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -952,6 +976,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -952,6 +976,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
...@@ -986,6 +1011,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -986,6 +1011,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment