Commit 7a8352bc authored by danyao12
Browse files

Fix dropout-related bugs in the split kernels

parent edbb3439
...@@ -46,7 +46,8 @@ __global__ void ...@@ -46,7 +46,8 @@ __global__ void
const DGridDescriptor_M d_grid_desc_m, const DGridDescriptor_M d_grid_desc_m,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const float p_drop)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -66,7 +67,8 @@ __global__ void ...@@ -66,7 +67,8 @@ __global__ void
p_d_grid + d_batch_offset, p_d_grid + d_batch_offset,
y_grid_desc_mblock_mperblock_nblock_nperblock, y_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m, d_grid_desc_m,
block_2_ctile_map); block_2_ctile_map,
p_drop);
#else #else
ignore = p_y_grid; ignore = p_y_grid;
...@@ -77,6 +79,7 @@ __global__ void ...@@ -77,6 +79,7 @@ __global__ void
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = p_drop;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -1131,7 +1134,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1131,7 +1134,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
arg.d_grid_desc_m_, arg.d_grid_desc_m_,
arg.d_block_2_ctile_map_, arg.d_block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_); arg.compute_base_ptr_of_batch_,
arg.p_drop_);
}; };
ave_time = launch_kernel(); ave_time = launch_kernel();
......
...@@ -46,7 +46,8 @@ __global__ void ...@@ -46,7 +46,8 @@ __global__ void
const DGridDescriptor_M d_grid_desc_m, const DGridDescriptor_M d_grid_desc_m,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const float p_drop)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -66,7 +67,8 @@ __global__ void ...@@ -66,7 +67,8 @@ __global__ void
p_d_grid + d_batch_offset, p_d_grid + d_batch_offset,
y_grid_desc_mblock_mperblock_nblock_nperblock, y_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m, d_grid_desc_m,
block_2_ctile_map); block_2_ctile_map,
p_drop);
#else #else
ignore = p_y_grid; ignore = p_y_grid;
...@@ -77,6 +79,7 @@ __global__ void ...@@ -77,6 +79,7 @@ __global__ void
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = p_drop;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -1143,7 +1146,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1143,7 +1146,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
arg.d_grid_desc_m_, arg.d_grid_desc_m_,
arg.d_block_2_ctile_map_, arg.d_block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_); arg.compute_base_ptr_of_batch_,
arg.p_drop_);
}; };
ave_time = launch_kernel(); ave_time = launch_kernel();
......
...@@ -33,7 +33,9 @@ __global__ void ...@@ -33,7 +33,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_grouped_multihead_attention_backward_ydotygrad_v1( kernel_grouped_multihead_attention_backward_ydotygrad_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count) const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const float p_dropout)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -74,10 +76,12 @@ __global__ void ...@@ -74,10 +76,12 @@ __global__ void
arg_ptr[group_id].p_d_grid_ + d_batch_offset, arg_ptr[group_id].p_d_grid_ + d_batch_offset,
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m_, arg_ptr[group_id].d_grid_desc_m_,
arg_ptr[group_id].d_block_2_ctile_map_); arg_ptr[group_id].d_block_2_ctile_map_,
p_dropout);
#else #else
ignore = group_kernel_args; ignore = group_kernel_args;
ignore = group_count; ignore = group_count;
ignore = p_dropout;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -1175,7 +1179,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1175,7 +1179,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(arg.p_workspace_), cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.group_count_); arg.group_count_,
arg.p_dropout_);
}; };
ave_time = launch_kernel(); ave_time = launch_kernel();
} }
......
...@@ -32,7 +32,9 @@ __global__ void ...@@ -32,7 +32,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_grouped_multihead_attention_backward_ydotygrad_v2( kernel_grouped_multihead_attention_backward_ydotygrad_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count) const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const float p_dropout)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -73,10 +75,12 @@ __global__ void ...@@ -73,10 +75,12 @@ __global__ void
arg_ptr[group_id].p_d_grid_ + d_batch_offset, arg_ptr[group_id].p_d_grid_ + d_batch_offset,
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m_, arg_ptr[group_id].d_grid_desc_m_,
arg_ptr[group_id].d_block_2_ctile_map_); arg_ptr[group_id].d_block_2_ctile_map_,
p_dropout);
#else #else
ignore = group_kernel_args; ignore = group_kernel_args;
ignore = group_count; ignore = group_count;
ignore = p_dropout;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -1244,7 +1248,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1244,7 +1248,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(arg.p_workspace_), cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.group_count_); arg.group_count_,
arg.p_dropout_);
}; };
ave_time = launch_kernel(); ave_time = launch_kernel();
} }
......
...@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
y_grid_desc_mblock_mperblock_nblock_nperblock, y_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDesc_M& d_grid_desc_m, const DGridDesc_M& d_grid_desc_m,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map,
const float p_drop)
{ {
const FloatD p_dropout = type_convert<FloatD>(1.0f - p_drop);
const tensor_operation::element_wise::Scale scale_p_dropout(p_dropout);
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
FloatD, FloatD,
decltype(d_thread_desc_mblock_m1), decltype(d_thread_desc_mblock_m1),
decltype(d_grid_desc_mblock_mperblock), decltype(d_grid_desc_mblock_mperblock),
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::Scale,
Sequence<1, 1>, Sequence<1, 1>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
...@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
d_grid_desc_mblock_mperblock, d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx_m, // mblock make_multi_index(block_work_idx_m, // mblock
get_thread_local_1d_id()), // mperblock get_thread_local_1d_id()), // mperblock
ck::tensor_operation::element_wise::PassThrough{}}; scale_p_dropout};
// copy from VGPR to Global // copy from VGPR to Global
d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_m1, d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_m1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment