Commit 36dc18e8 authored by danyao12's avatar danyao12

Merge branch 'attn-bwd-develop' into attn-bwd-dropout-pt1

parents 56befb6e 6fd1490b
@@ -10,13 +10,12 @@
 #include "ck/utility/philox_rand.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
-// #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -52,7 +51,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
+        kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
            const DataType* __restrict__ p_a_grid,
            const DataType* __restrict__ p_b_grid,
            ZDataType* __restrict__ p_z_grid,
@@ -82,7 +81,7 @@ __global__ void
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
            const C0MatrixMask c0_matrix_mask,
-           const float p_dropout,
+           const float p_drop,
            const unsigned long long seed,
            const unsigned long long offset)
 {
@@ -137,7 +136,7 @@ __global__ void
        ygrad_grid_desc_m0_o_m1,
        block_2_ctile_map,
        c0_matrix_mask,
-       p_dropout,
+       p_drop,
        ph);
 #else
    ignore = p_a_grid;
@@ -157,6 +156,9 @@ __global__ void
    ignore = batch_count;
    ignore = compute_base_ptr_of_batch;
    ignore = c0_matrix_mask;
+   ignore = p_drop;
+   ignore = seed;
+   ignore = offset;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

@@ -756,7 +758,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
              z_grid_desc_g_m_n_,
              b1_grid_desc_g_n_k_,
              c_grid_desc_g_m_n_,
-             type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
+             type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
+          p_drop_{p_drop}
    {
        // TODO: implement bias addition
        ignore = p_acc0_biases;
@@ -777,10 +780,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                y_grid_desc_m_o_);
        }

-       p_dropout_ = 1.f - p_drop;
-       float rp_dropout_ = 1.f / p_dropout_;
-       acc_element_op_.Append(rp_dropout_);
-
        seed_   = std::get<0>(seeds);
        offset_ = std::get<1>(seeds);

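
Note: this hunk moves the inverted-dropout arithmetic out of the host-side `Argument` constructor; the kernel now derives everything from the raw `p_drop` (see the gridwise hunks below). For reference, a minimal host-only sketch of that arithmetic, with names mirroring the diff and an illustrative `main()`:

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    const float p_drop = 0.1f; // probability of zeroing an element; must be in [0, 1)
    assert(p_drop >= 0.0f && p_drop < 1.0f);

    const float p_dropout  = 1.f - p_drop;    // keep probability
    const float rp_dropout = 1.f / p_dropout; // rescale so E[output] == input

    // Inverted dropout: a kept activation is scaled up by rp_dropout and a
    // dropped one is zeroed, so no extra rescaling is needed at inference time.
    const float x = 2.0f;
    std::printf("kept: %f, dropped: %f\n", x * rp_dropout, 0.0f);
}
```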
@@ -871,7 +870,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
-       float p_dropout_;
+       float p_drop_;
        unsigned long long seed_;
        unsigned long long offset_;
    };
@@ -898,7 +897,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
        float ave_time = 0;

        auto launch_kernel = [&](auto has_main_k_block_loop_) {
-           const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
+           const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
                GridwiseGemm,
                DataType,
                ZDataType,
@@ -953,7 +952,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                arg.batch_count_,
                arg.compute_base_ptr_of_batch_,
                arg.c0_matrix_mask_,
-               arg.p_dropout_,
+               arg.p_drop_,
                arg.seed_,
                arg.offset_);
        };
...
@@ -95,7 +95,7 @@ struct Scale
        y = scale_ * x;
    };

-   __host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
+   __host__ __device__ auto Value() const { return scale_; }

    float scale_;
};
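
With the mutating `Append()` gone, the `Scale` element-wise op is no longer modified in place; callers read the current scale via `Value()` and construct a new, combined op. A minimal host-only sketch of the changed functor and the composition pattern the gridwise kernel now uses (the `__host__ __device__` qualifiers are dropped here for illustration):

```cpp
// Sketch of the Scale functor after this change.
struct Scale
{
    explicit Scale(float scale) : scale_{scale} {}

    // y = scale * x, unchanged from the original operator().
    void operator()(float& y, const float& x) const { y = scale_ * x; }

    // New: expose the scale instead of letting callers mutate it via Append().
    float Value() const { return scale_; }

    float scale_;
};

// Composition pattern from the kernel: fold the dropout rescale into a fresh
// op rather than appending it to the shared s_element_op.
Scale make_scale_rp_dropout(const Scale& s_element_op, float rp_dropout)
{
    return Scale{s_element_op.Value() * rp_dropout};
}
```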
...
@@ -1185,11 +1185,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
        const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
        const Block2CTileMap& block_2_ctile_map,
        const C0MatrixMask& c0_matrix_mask,
-       FloatGemmAcc p_dropout,
+       const float p_drop,
        ck::philox& ph)
    {
+       const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
+       const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
        const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
-       const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
+       const bool is_dropout = p_drop > 0.0f;
+       const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
+                                                                    rp_dropout);

        const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
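
Quantizing the keep probability into `p_dropout_in_16bits` once lets the per-element keep/drop decision be an integer compare against Philox output rather than a float compare. A self-contained sketch of that test; the `<=` keep rule and the `std::mt19937` stand-in for the per-thread Philox stream are assumptions for illustration:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <random>

int main()
{
    const float p_drop    = 0.1f;
    const float p_dropout = 1.0f - p_drop; // keep probability
    const auto p_dropout_in_16bits =
        static_cast<std::uint16_t>(std::floor(p_dropout * 65535.0));

    // std::mt19937 stands in for the Philox stream used on the GPU.
    std::mt19937 rng{42};
    std::uniform_int_distribution<std::uint32_t> dist{0, 65535};

    int kept        = 0;
    const int total = 100000;
    for(int i = 0; i < total; ++i)
        kept += (dist(rng) <= p_dropout_in_16bits); // one integer compare per element

    // The empirical keep rate should land near p_dropout (about 0.9 here).
    std::printf("empirical keep rate: %f\n", static_cast<double>(kept) / total);
}
```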
@@ -1509,7 +1513,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                                               n4>,
                Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                9,  // DstVectorDim
-               n4, // DstScalarPerVector
+               1,  // DstScalarPerVector
                InMemoryDataOperationEnum::Set,
                1, // DstScalarStrideInVector
                true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
@@ -1619,9 +1623,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2

        auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
            decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
-           decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-                                   kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
-                                   s_element_op);
+           decltype(scale_rp_dropout)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                       kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
+                                       scale_rp_dropout);

        //
        // set up Y dot dY
@@ -1764,8 +1768,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
        const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

-       const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
-       const float scalar = 1.0f / std::sqrt(K);

        // Initialize dQ
        qgrad_thread_buf.Clear();
@@ -1846,14 +1848,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                    }
                    else
                    {
-                       s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
+                       s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
                    }
                });
            }
            else
            {
-               static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
-                   [&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
+               static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
+                   s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
+               });
            }

            block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
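
With the hard-coded `scalar = 1.0f / std::sqrt(K)` removed, every S element now flows through the caller-supplied `s_element_op`, so the softmax scale is owned by the device API instead of being recomputed inside the kernel. A simplified sketch of the loop, with a lambda standing in for the `Scale` functor and a plain array standing in for the thread buffer:

```cpp
#include <cmath>

// s_element_op is any elementwise functor with signature op(y, x); the loop
// applies it in place of the removed scalar multiply.
template <typename ElementOp>
void apply_s_element_op(float* buf, int n, const ElementOp& op)
{
    for(int i = 0; i < n; ++i)
        op(buf[i], buf[i]); // y and x may alias, as in the thread-buffer loop
}

int main()
{
    float s[4]        = {1.f, 2.f, 3.f, 4.f};
    const int K       = 64; // head dimension; illustrative value
    const float scale = 1.0f / std::sqrt(static_cast<float>(K));
    // Typical caller: pass the 1/sqrt(K) attention scale as the element op.
    apply_s_element_op(s, 4, [scale](float& y, const float& x) { y = scale * x; });
}
```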
@@ -1863,6 +1866,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);

            // save z to global
+           if(is_dropout)
+           {
                if(p_z_grid)
                {
                    // P_dropped
@@ -1871,7 +1876,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                        true>(
                        s_slash_p_thread_buf, ph, z_tenor_buffer);

-                   z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                   z_thread_copy_vgpr_to_global.Run(
+                       z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                        z_tenor_buffer,
                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
@@ -1883,6 +1889,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                    blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
                        s_slash_p_thread_buf, ph);
                }
+           }

            block_sync_lds(); // wait for gemm1 LDS read
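
The new `is_dropout` guard means that with `p_drop == 0` the kernel skips both the Z-matrix save and `ApplyDropout`, so the random-number stream is never consumed on the no-dropout path. A structural sketch of the control flow; the callable parameters are stand-ins for the two `blockwise_dropout.ApplyDropout` overloads:

```cpp
#include <cstddef>

// DropFn / DropRecordFn stand in for ApplyDropout without and with the
// Z-buffer overload, respectively.
template <typename DropFn, typename DropRecordFn>
void dropout_step(float* p_buf,
                  unsigned short* z_buf,
                  std::size_t n,
                  float p_drop,
                  DropFn apply_dropout,
                  DropRecordFn apply_dropout_and_record)
{
    const bool is_dropout = p_drop > 0.0f;
    if(is_dropout)
    {
        if(z_buf != nullptr)
            apply_dropout_and_record(p_buf, z_buf, n); // also saves the mask (Z)
        else
            apply_dropout(p_buf, n);
    }
    // With p_drop == 0 the P buffer flows to the next GEMM untouched.
}
```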
@@ -2241,7 +2248,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                n_thread_data_on_block_idx[I2],
                n_thread_data_on_block_idx[I3],
                n_thread_data_on_block_idx[I4]),
-           s_element_op};
+           scale_rp_dropout};

            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
...