Commit f82a220f authored by guangzlu

v4 pass

parent ff88ffa4
@@ -766,7 +766,7 @@ int run(int argc, char* argv[])
auto argument = gemm.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
-static_cast<ZDataType*>(nullptr), // set to nullptr
+static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), // pass the z (dropout mask) buffer
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
......
@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
-#define DIM 64 // DIM should be a multiple of 8.
+#define DIM 32 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
@@ -710,17 +710,17 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-ck::index_t M = 1000; // 512
-ck::index_t N = 1000; // 512
+ck::index_t M = 500; // 512
+ck::index_t N = 500; // 512
ck::index_t K = DIM;
ck::index_t O = DIM;
-ck::index_t G0 = 4; // 54
-ck::index_t G1 = 6; // 16
+ck::index_t G0 = 2; // 54
+ck::index_t G1 = 1; // 16
bool input_permute = false;
bool output_permute = false;
-float p_drop = 0.0;
+float p_drop = 0.1;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
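The comment at the top of the hunk above describes how the batched GEMM output is reshaped to [G0, G1, M, O] and permuted with [0, 2, 1, 3]. A minimal host-side index-math sketch of that permute, with illustrative sizes and names that are not part of the commit:

```cpp
// Index-only illustration of the layout handling in the comment above:
// y_g_m_o stored as [G0*G1, M, O], reshaped to [G0, G1, M, O], then permuted
// with [0, 2, 1, 3] into [G0, M, G1, O]. Sizes and names are illustrative.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    const std::size_t G0 = 2, G1 = 3, M = 4, O = 5;

    std::vector<float> y_g_m_o(G0 * G1 * M * O);
    for(std::size_t i = 0; i < y_g_m_o.size(); ++i)
        y_g_m_o[i] = static_cast<float>(i);

    // y_g0_m_g1_o = permute(reshape(y_g_m_o, [G0, G1, M, O]), [0, 2, 1, 3])
    std::vector<float> y_g0_m_g1_o(G0 * M * G1 * O);
    for(std::size_t g0 = 0; g0 < G0; ++g0)
        for(std::size_t g1 = 0; g1 < G1; ++g1)
            for(std::size_t m = 0; m < M; ++m)
                for(std::size_t o = 0; o < O; ++o)
                {
                    const std::size_t src = ((g0 * G1 + g1) * M + m) * O + o;
                    const std::size_t dst = ((g0 * M + m) * G1 + g1) * O + o;
                    y_g0_m_g1_o[dst] = y_g_m_o[src];
                }

    std::cout << y_g0_m_g1_o.front() << " ... " << y_g0_m_g1_o.back() << '\n';
}
```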
@@ -944,7 +944,7 @@ int run(int argc, char* argv[])
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
-static_cast<ZDataType*>(nullptr),
+static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
@@ -998,7 +998,7 @@ int run(int argc, char* argv[])
auto argument_bwd = gemm_bwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
-static_cast<ZDataType*>(nullptr), // set to nullptr
+static_cast<ZDataType*>(z_bwd_device_buf.GetDeviceBuffer()), // pass the z (dropout mask) buffer
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
@@ -1399,20 +1399,20 @@ int run(int argc, char* argv[])
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData,
"error",
-1e-2,
-1e-2);
+1e-3,
+1e-3);
std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData,
"error",
-1e-2,
-1e-2);
+1e-3,
+1e-3);
std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData,
"error",
-1e-2,
-1e-2);
+1e-3,
+1e-3);
}
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
......
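The hunk above tightens the verification tolerances from 1e-2 to 1e-3; these are the rtol/atol pair forwarded to ck::utils::check_err. A standalone sketch of the kind of mixed relative/absolute comparison such a check performs (my own helper below, under stated assumptions, not CK's actual implementation):

```cpp
// Hypothetical stand-in for ck::utils::check_err: passes when every element
// satisfies |out - ref| <= atol + rtol * |ref|. Not CK's exact implementation.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

bool check_err(const std::vector<float>& out,
               const std::vector<float>& ref,
               const std::string& msg,
               double rtol = 1e-3,
               double atol = 1e-3)
{
    if(out.size() != ref.size())
    {
        std::cout << msg << ": size mismatch\n";
        return false;
    }
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::fabs(double(out[i]) - double(ref[i]));
        if(err > atol + rtol * std::fabs(double(ref[i])))
        {
            std::cout << msg << " at index " << i << ": " << err << '\n';
            return false;
        }
    }
    return true;
}

int main()
{
    std::vector<float> device{1.0004f, 2.f}, host{1.f, 2.f};
    std::cout << (check_err(device, host, "error", 1e-3, 1e-3) ? "pass\n" : "fail\n");
}
```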
@@ -145,14 +145,14 @@ struct BlockwiseDropout
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
}
-ushort tmp_id[tmp_size];
-for(int i = 0; i < philox_calls; i++)
-{
-for(int j = 0; j < 4; j++)
-{
-tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
-}
-}
+// ushort tmp_id[tmp_size];
+// for(int i = 0; i < philox_calls; i++)
+//{
+// for(int j = 0; j < 4; j++)
+// {
+// tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
+// }
+//}
block_sync_lds();
@@ -162,7 +162,7 @@ struct BlockwiseDropout
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
-z_thread_buf(offset) = tmp_id[tmp_index];
+z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
@@ -208,17 +208,17 @@ struct BlockwiseDropout
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
-ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
+ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
}
-ushort tmp_id[tmp_size];
-for(int i = 0; i < philox_calls; i++)
-{
-for(int j = 0; j < 4; j++)
-{
-tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
-}
-}
+// ushort tmp_id[tmp_size];
+// for(int i = 0; i < philox_calls; i++)
+//{
+// for(int j = 0; j < 4; j++)
+// {
+// tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
+// }
+//}
block_sync_lds();
@@ -226,7 +226,7 @@ struct BlockwiseDropout
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
-z_thread_buf(offset) = tmp_id[tmp_index];
+z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
......
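Both BlockwiseDropout hunks stop writing per-element ids (tmp_id) into z_thread_buf and store the raw 16-bit Philox outputs (tmp) instead, so the same keep/drop mask can be reconstructed later from the Z tensor. A host-side sketch of that idea, with stand-in types (a plain RNG replaces the device-side Philox; only the p_dropout_16bits comparison mirrors the kernel):

```cpp
// Forward: decide drop/keep from a 16-bit random value and store that value
// in Z. Backward: replay the identical mask from Z without re-running the RNG.
// Everything here is a hypothetical host-side stand-in, not the CK kernel.
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int main()
{
    const float p_drop = 0.1f;
    const uint16_t p_dropout_16bits = static_cast<uint16_t>(p_drop * 65535.f);
    const float rescale = 1.f / (1.f - p_drop); // keeps the expected value unchanged

    std::mt19937 rng(1); // stand-in for the device-side Philox generator
    std::vector<uint16_t> z(8);
    std::vector<float> fwd(8, 1.f);

    // forward pass: generate, threshold, and keep the raw value in Z
    for(std::size_t i = 0; i < z.size(); ++i)
    {
        z[i]   = static_cast<uint16_t>(rng());
        fwd[i] = (z[i] <= p_dropout_16bits) ? 0.f : fwd[i] * rescale;
    }

    // backward pass: rebuild exactly the same mask from the stored values
    std::vector<float> grad(8, 1.f);
    for(std::size_t i = 0; i < z.size(); ++i)
        grad[i] = (z[i] <= p_dropout_16bits) ? 0.f : grad[i] * rescale;

    for(std::size_t i = 0; i < z.size(); ++i)
        std::cout << fwd[i] << '/' << grad[i] << ' ';
    std::cout << '\n';
}
```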
@@ -22,6 +22,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/device_memory.hpp"
namespace ck {
namespace tensor_operation {
@@ -40,6 +41,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
+typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5,
typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
@@ -73,6 +75,8 @@ __global__ void
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -138,6 +142,7 @@ __global__ void
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
@@ -173,6 +178,7 @@ __global__ void
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
@@ -828,6 +834,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_ =
+GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4(
+z_grid_desc_m_n_);
+// tmp z tensor for shuffle
+// Tensor<ZDataType> z_tmp_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
+// DeviceMem z_tmp_device_buf(sizeof(ZDataType) *
+// z_tmp_gs_ms_ns.mDesc.GetElementSpaceSize());
+// z_tmp_device_buf.ToDevice(z_tmp_gs_ms_ns.mData.data());
+// p_z_tmp_grid_ = reinterpret_cast<ZDataType*>(z_tmp_device_buf.GetDeviceBuffer());
+// Print();
}
@@ -859,7 +877,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// pointers
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
+// ZDataType* p_z_tmp_grid_;
ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
@@ -890,6 +911,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
+typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -952,6 +976,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
+typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
@@ -986,6 +1011,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
+arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
......
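The hunks in this file thread a second Z grid descriptor (with the M3_M4_N3_M5 element order) from the device op's Argument through the kernel template parameter list and into the kernel invocation. A stripped-down illustration of that plumbing pattern, using generic placeholder types rather than CK's real classes:

```cpp
// Generic sketch of the change pattern above: a second descriptor type joins
// the template parameter list, the Argument struct gains a matching member,
// and the invoker forwards it as one more kernel parameter. All names below
// are placeholders, not CK's actual types.
#include <iostream>

struct ZDescOldOrder { int length; }; // stands for ..._M3_M4_M5_N3
struct ZDescNewOrder { int length; }; // stands for the added ..._M3_M4_N3_M5

template <typename ZDescA, typename ZDescB>
void kernel_entry(const ZDescA& z_a, const ZDescB& z_b)
{
    // the kernel body can now address the Z tensor through either ordering
    std::cout << "old-order length " << z_a.length
              << ", new-order length " << z_b.length << '\n';
}

struct Argument
{
    ZDescOldOrder z_desc_old{100};
    ZDescNewOrder z_desc_new{100}; // new member, like c_grid_desc_..._m3_m4_n3_m5_
};

int main()
{
    Argument arg;
    // the invoker forwards the extra descriptor alongside the existing one
    kernel_entry(arg.z_desc_old, arg.z_desc_new);
}
```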