Commit 9ec592a6 authored by guangzlu's avatar guangzlu
Browse files

added new dropout based on global position for fwd pass

parent a3def557
......@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -715,13 +715,13 @@ int run(int argc, char* argv[])
ck::index_t M = 500; // 512
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4; // 54
ck::index_t G1 = 6; // 16
ck::index_t G0 = 1; // 54
ck::index_t G1 = 2; // 16
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.0;
float p_drop = 0.1;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......
......@@ -122,6 +122,152 @@ struct BlockwiseDropout
});
}
template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void
ApplyDropout_v1r1(CThreadBuffer& in_thread_buf, ck::philox& ph, index_t element_global_1d_id)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout_v1r2(CThreadBuffer& in_thread_buf,
ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
}
ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
for(int j = 0; j < 4; j++)
{
tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
}
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp_id[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout_v2(CThreadBuffer& in_thread_buf,
ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
});
});
}
// get raw z matrix with random number for shuffle
template <typename ZThreadBuffer>
__host__ __device__ void GenerateZMatrix(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf,
index_t MRaw)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
}
ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
for(int j = 0; j < 4; j++)
{
tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
}
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp_id[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
ushort p_dropout_16bits;
DataType p_dropout_rescale;
};
......
......@@ -79,7 +79,9 @@ __global__ void
const ushort p_dropout_in_16bits,
const GemmAccDataType p_dropout_rescale,
const unsigned long long seed,
const unsigned long long offset)
const unsigned long long offset,
const index_t MRaw,
const index_t NRaw)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -131,6 +133,9 @@ __global__ void
p_dropout_in_16bits,
p_dropout_rescale,
ph,
g_idx,
MRaw,
NRaw,
i);
}
}
......@@ -160,6 +165,9 @@ __global__ void
p_dropout_in_16bits,
p_dropout_rescale,
ph,
g_idx,
MRaw,
NRaw,
0);
}
#else
......@@ -803,7 +811,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg.p_dropout_in_16bits_,
arg.p_dropout_rescale_,
arg.seed_,
arg.offset_);
arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
......@@ -447,6 +447,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout_rescale,
ck::philox& ph,
const index_t g_idx,
const index_t MRaw,
const index_t NRaw,
const index_t block_idx_m)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -857,7 +860,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
I1, // NRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -888,7 +891,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
I1, // NRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -917,7 +920,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
block_sync_lds();
}
do
{
auto n_block_data_idx_on_grid =
......@@ -1012,31 +1015,84 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if constexpr(IsDropout) // dropout
{
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto N2 = c_block_lengths[I5];
constexpr auto N3 = c_block_lengths[I6];
constexpr auto N4 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
if(p_z_grid)
{
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
false,
decltype(n0),
decltype(i)>(
acc_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
blockwise_dropout.template ApplyDropout_v1r2<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
false>(
acc_thread_buf, ph, global_elem_id, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_grid_buf);
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
// decltype(z_tenor_buffer),
// false,
// decltype(n0),
// decltype(i)>(
// acc_thread_buf, ph, global_elem_id + id_step * i.value,
// z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
......@@ -1045,8 +1101,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
// ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>(
acc_thread_buf, ph);
blockwise_dropout.template ApplyDropout_v1r1<decltype(acc_thread_buf), false>(
acc_thread_buf, ph, global_elem_id);
}
}
......
......@@ -84,6 +84,17 @@ class philox
out_tmp[3] = tmp_ph.w;
}
__device__ void get_random_4x16(ushort* out, const unsigned long long subsequence)
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
out[0] = static_cast<ushort>(tmp_ph.x);
out[1] = static_cast<ushort>(tmp_ph.y);
out[2] = static_cast<ushort>(tmp_ph.z);
out[3] = static_cast<ushort>(tmp_ph.w);
}
private:
struct ull2
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment