Commit ff88ffa4 authored by guangzlu

bwd pass for v4

parent 01073007
......@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#define DIM 32 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -497,7 +497,7 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.0;
float p_drop = 0.1;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......
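The first two hunks only touch the example driver: DIM is halved from 64 to 32 and dropout is switched on with p_drop = 0.1. For context, here is a minimal host-side sketch of the standard inverted-dropout convention that the kernel's `p_dropout_rescale` (used further down) is assumed to follow, i.e. kept elements are scaled by 1 / (1 - p_drop) so the expected value is preserved; this is an illustration, not code from the commit.

```cpp
// Standalone sketch, not part of the commit: inverted dropout, assuming
// p_dropout_rescale in the kernel equals 1 / (1 - p_drop). Kept elements are
// rescaled so the expected value of the output matches the input.
#include <cstdio>
#include <random>

int main()
{
    const float p_drop  = 0.1f;              // same value as enabled above
    const float rescale = 1.0f / (1.0f - p_drop);

    std::mt19937 gen(1);                     // seed 1, as in the example driver
    std::uniform_real_distribution<float> uni(0.0f, 1.0f);

    const int n = 1 << 20;
    double sum  = 0.0;
    for(int i = 0; i < n; ++i)
    {
        const float x    = 1.0f;             // constant input for clarity
        const bool  keep = uni(gen) >= p_drop;
        sum += keep ? x * rescale : 0.0f;
    }
    std::printf("mean after dropout = %f (expected ~1.0)\n", sum / n);
    return 0;
}
```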
......@@ -145,6 +145,15 @@ struct BlockwiseDropout
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
}
ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
for(int j = 0; j < 4; j++)
{
tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
}
}
block_sync_lds();
int tmp_index = 0;
......@@ -153,7 +162,71 @@ struct BlockwiseDropout
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index];
z_thread_buf(offset) = tmp_id[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout_v2(CThreadBuffer& in_thread_buf,
ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
});
});
}
// get raw z matrix with random number for shuffle
template <typename ZThreadBuffer>
__host__ __device__ void GenerateZMatrix(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf,
index_t MRaw)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
}
ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
for(int j = 0; j < 4; j++)
{
tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
}
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp_id[tmp_index];
tmp_index = tmp_index + 1;
});
});
......
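The new BlockwiseDropout helpers split the work into two phases: GenerateZMatrix stages per-element values into the z thread buffer once (in this WIP commit the shuffled element ids, with the Philox draws kept in `tmp`), and ApplyDropout_v2 then applies dropout purely from whatever z_thread_buf holds, without calling the generator again. Below is a standalone sketch of that two-phase shape, with a plain 16-bit RNG standing in for ck::philox and the threshold assumed to encode the keep probability, to match the `keep ? val * p_dropout_rescale : 0` lambda above; it is not the CK implementation.

```cpp
// Standalone sketch, not part of the commit: generate-once / apply-later dropout.
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

int main()
{
    const float    p_drop        = 0.1f;
    const float    rescale       = 1.0f / (1.0f - p_drop);
    // assumed encoding: 16-bit keep-probability threshold
    const uint16_t p_keep_16bits = static_cast<uint16_t>((1.0f - p_drop) * 65535.0f);

    std::mt19937 gen(1);
    std::uniform_int_distribution<uint32_t> rnd(0, 65535);

    // Phase 1: fill the z buffer once (GenerateZMatrix's role).
    std::vector<uint16_t> z(16);
    for(auto& v : z)
        v = static_cast<uint16_t>(rnd(gen));

    // Phase 2: apply dropout purely from the stored z values (ApplyDropout_v2's role).
    std::vector<float> buf(16, 1.0f);
    for(size_t i = 0; i < buf.size(); ++i)
    {
        const bool keep = z[i] <= p_keep_16bits; // same comparison shape as the kernel
        buf[i]          = keep ? buf[i] * rescale : 0.0f;
    }

    for(float v : buf)
        std::printf("%.2f ", v);
    std::printf("\n");
    return 0;
}
```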
......@@ -1600,13 +1600,83 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
z_tenor_buffer;
z_tenor_buffer.Clear();
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer_tmp;
z_tenor_buffer_tmp.Clear();
// z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_grid_buf_tmp =
make_dynamic_buffer<AddressSpaceEnum::Global>( // tmp buffer for shuffle
p_z_grid,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_tmp_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
ZDataType,
decltype(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], // group
0, // NInputIndex
wave_m_n_id[I1]),
tensor_operation::element_wise::PassThrough{}};
auto z_tmp_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<ZDataType,
ushort,
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
Sequence<I1, I1, m0, n0, m1, n1, m2, m3, m4, n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
1,
1,
true /* ResetCoordAfterRun */>{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], // group
0, // NInputIndex
wave_m_n_id[I1])};
auto z_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
ZDataType,
......@@ -1986,9 +2056,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id =
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
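The element id used for the z matrix is now remapped before use: within a group of four consecutive raw ids the shuffled ids end up MRaw apart, and each completed group of four advances the base by four. A tiny host-side check of the same formula (MRaw = 32 is chosen only for illustration; it is not taken from the commit):

```cpp
// Standalone sketch, not part of the commit: prints what the remapping above
// does to consecutive raw element ids.
#include <cstdio>

int main()
{
    const int MRaw = 32; // illustration only; the kernel passes the real MRaw

    for(int raw = 0; raw < 12; ++raw)
    {
        const int shuffled = (raw % 4) * MRaw + (raw / 4) * 4; // formula from the diff
        std::printf("raw %2d -> shuffled %3d\n", raw, shuffled);
    }
    return 0;
}
```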
......@@ -2001,10 +2074,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// printf("id_step is %d \n", id_step);
//}
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
// dropout
// z_tenor_buffer_tmp -> z_grid_buf_tmp -> shuffle -> z_tenor_buffer -> z_grid_buf
// generate random number
blockwise_dropout.template GenerateZMatrix<decltype(z_tenor_buffer_tmp)>(
ph, global_elem_id, z_tenor_buffer_tmp, MRaw);
z_tmp_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer_tmp,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf_tmp);
z_tmp_thread_copy_global_to_vgpr.Run(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf_tmp,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer);
blockwise_dropout.template ApplyDropout_v2<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer);
true>(s_slash_p_thread_buf,
z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
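The new path stages the freshly generated z values to global memory and immediately reads them back through a second threadwise copy with a different slice layout, so the registers hold shuffled values before ApplyDropout_v2 runs. Below is a minimal sketch of that "shuffle through memory" round trip behind the z_tenor_buffer_tmp -> z_grid_buf_tmp -> z_tenor_buffer comment, using plain arrays instead of the CK ThreadwiseTensorSliceTransfer classes; it only illustrates the idea, not the actual layouts.

```cpp
// Standalone sketch, not part of the commit: write with one index mapping,
// read back with another, so the consumer sees a permuted layout.
#include <cstdio>
#include <vector>

int main()
{
    const int M = 4, N = 4;
    std::vector<int> staging(M * N);

    // "vgpr -> global": write with a column-major mapping; the stored value
    // records which (m, n) slot produced it.
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            staging[n * M + m] = m * N + n;

    // "global -> vgpr": read back row-major, i.e. a transposed (shuffled) view.
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
            std::printf("%3d ", staging[m * N + n]);
        std::printf("\n");
    }
    return 0;
}
```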
......@@ -2079,6 +2173,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
ignore = z_grid_buf;
ignore = z_grid_buf_tmp;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, global_elem_id);
......@@ -2266,6 +2361,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Gemm1::b_block_reset_copy_step); // rewind M
qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
z_tmp_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
......