"profiler/vscode:/vscode.git/clone" did not exist on "03b36fbc829d4aed0c60a3b5aa5640b20865cf13"
Commit 36dc18e8 authored by danyao12's avatar danyao12
Browse files

Merge branch 'attn-bwd-develop' into attn-bwd-dropout-pt1

parents 56befb6e 6fd1490b
......@@ -10,13 +10,12 @@
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
// #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -52,7 +51,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
......@@ -82,7 +81,7 @@ __global__ void
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const float p_dropout,
const float p_drop,
const unsigned long long seed,
const unsigned long long offset)
{
......@@ -137,7 +136,7 @@ __global__ void
ygrad_grid_desc_m0_o_m1,
block_2_ctile_map,
c0_matrix_mask,
p_dropout,
p_drop,
ph);
#else
ignore = p_a_grid;
......@@ -157,6 +156,9 @@ __global__ void
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_drop;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -756,7 +758,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc0_biases;
......@@ -777,10 +780,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
y_grid_desc_m_o_);
}
p_dropout_ = 1.f - p_drop;
float rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -871,7 +870,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_;
float p_drop_;
unsigned long long seed_;
unsigned long long offset_;
};
......@@ -898,7 +897,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
DataType,
ZDataType,
......@@ -953,7 +952,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
arg.batch_count_,
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_dropout_,
arg.p_drop_,
arg.seed_,
arg.offset_);
};
......
......@@ -95,7 +95,7 @@ struct Scale
y = scale_ * x;
};
__host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
__host__ __device__ auto Value() const { return scale_; }
float scale_;
};
......
......@@ -1185,11 +1185,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
FloatGemmAcc p_dropout,
const float p_drop,
ck::philox& ph)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
const bool is_dropout = p_drop > 0.0f;
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
......@@ -1509,7 +1513,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
n4, // DstScalarPerVector
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......@@ -1619,9 +1623,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
decltype(scale_rp_dropout)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
s_element_op);
scale_rp_dropout);
//
// set up Y dot dY
......@@ -1764,8 +1768,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
const float scalar = 1.0f / std::sqrt(K);
// Initialize dQ
qgrad_thread_buf.Clear();
......@@ -1846,14 +1848,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
else
{
s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
}
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
[&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......@@ -1863,6 +1866,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if(is_dropout)
{
if(p_z_grid)
{
// P_dropped
......@@ -1871,7 +1876,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
true>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......@@ -1883,6 +1889,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
}
block_sync_lds(); // wait for gemm1 LDS read
......@@ -2241,7 +2248,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
s_element_op};
scale_rp_dropout};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment