Commit 5012068b authored by ltqin

start adding drop in device

parent 17bb1aaa
@@ -10,6 +10,7 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
@@ -255,12 +255,6 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;
-    float p_drop           = 0.2;
-    float p_dropout        = 1 - p_drop;
-    float rp_dropout       = 1.0 / p_dropout;
-    float scale_rp_dropout = alpha * rp_dropout;
     if(argc == 1)
     {
@@ -485,7 +479,7 @@ int run(int argc, char* argv[])
        {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
        QKVElementOp{},
        QKVElementOp{},
-       Scale{scale_rp_dropout}, // dQ *= scale_rp_dropout
+       Scale{alpha},
        QKVElementOp{},
        YElementOp{});
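
For context, the four host-side variables deleted above implement standard inverted dropout: an element survives with probability p_dropout = 1 - p_drop and is rescaled by rp_dropout = 1 / p_dropout so its expected value is unchanged, and the softmax scale alpha is folded in as the fused dQ scale. A minimal standalone sketch of that arithmetic (variable names mirror the deleted lines; the alpha value is a placeholder, not taken from this commit):

#include <cstdio>

int main()
{
    // Placeholder softmax scale; the example would use something like 1/sqrt(head_dim).
    float alpha = 0.125f;

    float p_drop           = 0.2f;               // probability of zeroing an element
    float p_dropout        = 1.0f - p_drop;      // keep probability
    float rp_dropout       = 1.0f / p_dropout;   // rescale kept elements so E[x] is preserved
    float scale_rp_dropout = alpha * rp_dropout; // fused scale applied when computing dQ

    std::printf("keep prob %.2f, rescale %.4f, fused dQ scale %.4f\n",
                p_dropout, rp_dropout, scale_rp_dropout);
    return 0;
}

With p_drop = 0.2 this gives rp_dropout = 1.25, so the fused dQ scale is simply alpha * 1.25.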
......
@@ -16,7 +16,7 @@
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -97,14 +97,6 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
-    float p_dropout                  = 1 - 0.2;
-    const ushort p_dropout_in_16bits = 65536 * p_dropout;
-    float rp_dropout                 = 1.0 / p_dropout;
-    const unsigned long long seed = 0;
-    const index_t block_id        = get_block_1d_id();
-    ck::philox ph(seed, 0, block_id * 4);
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_b1_grid + b1_batch_offset,
@@ -128,11 +120,7 @@ __global__ void
                                                   vgrad_grid_desc_n_o,
                                                   ygrad_grid_desc_m0_o_m1,
                                                   block_2_ctile_map,
-                                                  c0_matrix_mask,
-                                                  p_dropout_in_16bits,
-                                                  p_dropout,
-                                                  rp_dropout,
-                                                  ph);
+                                                  c0_matrix_mask);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
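
The kernel-side block removed above hard-coded the dropout state in the wrapper: the keep probability was turned into a 16-bit threshold (p_dropout_in_16bits = 65536 * p_dropout) for comparison against 16-bit lanes of a Philox output, and a ck::philox engine was seeded per block with offset block_id * 4. Below is a hedged sketch of the threshold test only; rand16 stands in for one uniform 16-bit RNG lane, and this is not ck::philox's actual interface:

#include <cstdint>

// Dropout by 16-bit threshold comparison (sketch, assuming uniform 16-bit lanes).
inline float dropout_scale(float x,
                           uint16_t rand16,              // one 16-bit RNG lane
                           uint16_t p_dropout_in_16bits, // keep-prob * 65536, truncated
                           float rp_dropout)             // 1 / keep-prob
{
    // Keep the element iff the lane falls below the threshold, then rescale
    // so the expected value matches the no-dropout path.
    return rand16 < p_dropout_in_16bits ? x * rp_dropout : 0.0f;
}

For a keep probability of 0.8 the threshold truncates to 52428, so roughly 80% of lanes pass; a keep probability of exactly 1.0 does not fit in 16 bits, which is one reason to special-case the no-dropout path.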
@@ -567,7 +555,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
    };
    // GridwiseGemm
-   using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+   using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle<
        DataType, // TODO: distinguish A/B datatype
        LSEDataType,
        GemmAccDataType,
......