Commit e3a2651b authored by qin letao

add dropout gridwise bwd

parent 70f2adc5
@@ -7,6 +7,7 @@
 #include <sstream>
 
 #include "ck/utility/common_header.hpp"
+#include "ck/utility/philox_rand.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 // #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
@@ -15,7 +16,7 @@
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -96,6 +97,14 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
 
+    float p_dropout                  = 1 - 0.2;
+    const ushort p_dropout_in_16bits = 65536 * p_dropout;
+    float rp_dropout                 = 1.0 / p_dropout;
+    const unsigned long long seed    = 0;
+
+    const index_t block_id = get_block_1d_id();
+    ck::philox ph(seed, 0, block_id * 4);
+
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_b1_grid + b1_batch_offset,
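For reference, the added constants encode inverted dropout: with the hard-coded drop probability of 0.2 the keep probability is 0.8, `p_dropout_in_16bits` scales that probability into the 16-bit range for comparison against Philox-generated random shorts, and `rp_dropout` rescales kept elements by 1/0.8 so the expected value is unchanged. A minimal host-side sketch of the same derivation, using a hypothetical `make_dropout_params` helper (not part of CK):

```cpp
#include <cstdint>
#include <cstdio>

struct DropoutParams
{
    uint16_t keep_threshold; // a 16-bit random draw below this value => keep element
    float rescale;           // 1 / keep_prob, applied to kept elements
};

// Hypothetical helper mirroring the constants computed in the kernel above.
DropoutParams make_dropout_params(float drop_prob)
{
    const float keep_prob = 1.0f - drop_prob;
    // Caveat: keep_prob == 1.0 would yield 65536, which does not fit in
    // uint16_t, so a real implementation must special-case drop_prob == 0.
    return {static_cast<uint16_t>(65536.0f * keep_prob), 1.0f / keep_prob};
}

int main()
{
    const DropoutParams p = make_dropout_params(0.2f);
    std::printf("threshold = %u, rescale = %.2f\n",
                static_cast<unsigned>(p.keep_threshold), p.rescale);
    // prints: threshold = 52428, rescale = 1.25
    return 0;
}
```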
@@ -119,7 +128,10 @@ __global__ void
                                                   vgrad_grid_desc_n_o,
                                                   ygrad_grid_desc_m0_o_m1,
                                                   block_2_ctile_map,
-                                                  c0_matrix_mask);
+                                                  c0_matrix_mask,
+                                                  p_dropout_in_16bits,
+                                                  rp_dropout,
+                                                  ph);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
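The `ck::philox ph(seed, 0, block_id * 4)` construction gives each workgroup its own offset into a counter-based random stream, so the dropout mask is reproducible from the seed and offsets alone and never has to be materialized in memory. How the 16-bit draws are consumed is internal to the V2 gridwise kernel; a hedged sketch of the conventional threshold test (an assumption, not CK's exact code):

```cpp
#include <cstdint>

// Assumed masking rule: keep an element when its 16-bit random draw falls
// below the keep threshold, rescaling it so E[output] == input; else zero it.
inline float apply_dropout(float v, uint16_t rand16, uint16_t keep_threshold, float rescale)
{
    return rand16 < keep_threshold ? v * rescale : 0.0f;
}
```

Because the generator is counter-based, the backward pass can regenerate exactly the mask the forward pass used by replaying the same seed and per-block offsets.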
@@ -554,7 +566,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
     };
 
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
         DataType, // TODO: distinguish A/B datatype
         LSEDataType,
         GemmAccDataType,
...
@@ -1404,7 +1404,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy_dV{
             Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
             Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
-            tensor_operation::element_wise::Relu{}}; //relu(P-dropped)
+            tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
 
         // dV: B matrix global-to-LDS blockwise copy
         auto vgrad_gemm_tile_ygrad_blockwise_copy =
...
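The `relu(P-dropped)` comment hints at how the dropped probability matrix reaches the dV GEMM: one common encoding (an assumption here, not confirmed by this diff) stores dropped entries of P negated, so applying `Relu` in the A-side blockwise copy zeroes exactly the dropped elements without carrying a separate mask tensor. A hedged illustration:

```cpp
#include <algorithm>

// Assumed encoding: kept entries of P are rescaled and remain positive,
// dropped entries are negated (P's entries are non-negative probabilities).
inline float encode_dropped(float p, bool keep, float rescale)
{
    return keep ? p * rescale : -p;
}

// Relu then recovers the dropped matrix: dropped entries become 0,
// kept entries pass through unchanged.
inline float relu(float x) { return std::max(x, 0.0f); }
```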