Commit 203268f1 authored by guangzlu's avatar guangzlu
Browse files

added gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp

parents d0c65caa 82ce7f4e
...@@ -401,9 +401,9 @@ int run(int argc, char* argv[]) ...@@ -401,9 +401,9 @@ int run(int argc, char* argv[])
break; break;
case 4: case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
break; break;
case 5: case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
......
...@@ -44,7 +44,7 @@ struct BlockwiseDropout ...@@ -44,7 +44,7 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) { static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset)); execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
}); });
...@@ -79,7 +79,7 @@ struct BlockwiseDropout ...@@ -79,7 +79,7 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) { static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset)); execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index]; z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
......
...@@ -10,13 +10,12 @@ ...@@ -10,13 +10,12 @@
#include "ck/utility/philox_rand.hpp" #include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
// #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -52,7 +51,7 @@ __global__ void ...@@ -52,7 +51,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2( kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid, const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid, const DataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -82,7 +81,7 @@ __global__ void ...@@ -82,7 +81,7 @@ __global__ void
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const float p_dropout, const float p_drop,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
...@@ -137,7 +136,7 @@ __global__ void ...@@ -137,7 +136,7 @@ __global__ void
ygrad_grid_desc_m0_o_m1, ygrad_grid_desc_m0_o_m1,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
p_dropout, p_drop,
ph); ph);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
...@@ -157,6 +156,9 @@ __global__ void ...@@ -157,6 +156,9 @@ __global__ void
ignore = batch_count; ignore = batch_count;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask; ignore = c0_matrix_mask;
ignore = p_drop;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -756,7 +758,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -756,7 +758,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
z_grid_desc_g_m_n_, z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_, b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_, c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())} type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
ignore = p_acc0_biases; ignore = p_acc0_biases;
...@@ -777,10 +780,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -777,10 +780,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
y_grid_desc_m_o_); y_grid_desc_m_o_);
} }
p_dropout_ = 1.f - p_drop;
float rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_);
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -871,7 +870,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -871,7 +870,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_; float p_drop_;
unsigned long long seed_; unsigned long long seed_;
unsigned long long offset_; unsigned long long offset_;
}; };
...@@ -898,7 +897,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -898,7 +897,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2< const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm, GridwiseGemm,
DataType, DataType,
ZDataType, ZDataType,
...@@ -953,7 +952,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -953,7 +952,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_dropout_, arg.p_drop_,
arg.seed_, arg.seed_,
arg.offset_); arg.offset_);
}; };
......
...@@ -95,7 +95,7 @@ struct Scale ...@@ -95,7 +95,7 @@ struct Scale
y = scale_ * x; y = scale_ * x;
}; };
__host__ __device__ void Append(float scale) { scale_ = scale_ * scale; } __host__ __device__ auto Value() const { return scale_; }
float scale_; float scale_;
}; };
......
...@@ -48,10 +48,14 @@ struct ReferenceDropout : public device::BaseOperator ...@@ -48,10 +48,14 @@ struct ReferenceDropout : public device::BaseOperator
{ {
arg.out_.ForEach([&](auto& self, auto idx) { arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = self(idx) =
<<<<<<< HEAD
arg.ref_(idx) < arg.p_dropout_in_16bits_ arg.ref_(idx) < arg.p_dropout_in_16bits_
? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) * ? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) *
ck::type_convert<float>(arg.rp_dropout_)) ck::type_convert<float>(arg.rp_dropout_))
: 0; : 0;
=======
arg.ref_(idx) <= arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
>>>>>>> attn-bwd-develop
}); });
return 0; return 0;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment