Unverified Commit 3c1d5412 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #885 from ROCmSoftwarePlatform/mha-dropout-update

Simple batched dropout update
parents e8ef00ec e5f7c969
...@@ -234,9 +234,9 @@ int run(int argc, char* argv[]) ...@@ -234,9 +234,9 @@ int run(int argc, char* argv[])
auto dropout_arg = auto dropout_arg =
dropout_op.MakeArgument(static_cast<ZDataType*>(z_device_buf_2.GetDeviceBuffer()), dropout_op.MakeArgument(static_cast<ZDataType*>(z_device_buf_2.GetDeviceBuffer()),
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, {0, 0, 0, 0},
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides, {0, 0, 0, 0},
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
{seed, offset}); {seed, offset});
......
...@@ -27,7 +27,6 @@ namespace device { ...@@ -27,7 +27,6 @@ namespace device {
template <typename GridwiseDropout_, template <typename GridwiseDropout_,
typename ZDataType, typename ZDataType,
typename AGridDesc_AK0_M_AK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch> typename ComputeBasePtrOfStridedBatch>
...@@ -36,10 +35,10 @@ __global__ void ...@@ -36,10 +35,10 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_dropout(ZDataType* __restrict__ p_z_grid, kernel_batched_dropout(ZDataType* __restrict__ p_z_grid,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t num_gemm0_m_block_outer_loop,
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const unsigned long long seed, const unsigned long long seed,
...@@ -62,10 +61,10 @@ __global__ void ...@@ -62,10 +61,10 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
GridwiseDropout_::Run(z_matrix_ptr, GridwiseDropout_::Run(z_matrix_ptr,
a_grid_desc_ak0_m_ak1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
block_2_ctile_map, block_2_ctile_map,
ph, ph,
num_gemm0_m_block_outer_loop,
z_random_matrix_offset, z_random_matrix_offset,
raw_n_padded); raw_n_padded);
#else #else
...@@ -156,8 +155,7 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -156,8 +155,7 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
...@@ -182,7 +180,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -182,7 +180,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
using GridwiseDropout = GridwiseBatchedDropout<ZDataType, using GridwiseDropout = GridwiseBatchedDropout<ZDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
AGridDesc_AK0_M_AK1,
KGridDesc_N_K, KGridDesc_N_K,
ZGridDesc_M_N, ZGridDesc_M_N,
BlockSize, BlockSize,
...@@ -209,8 +206,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -209,8 +206,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: p_z_grid_{p_z_grid}, : p_z_grid_{p_z_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
k_grid_desc_n_k_{ k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
...@@ -233,6 +228,11 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -233,6 +228,11 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
z_grid_desc_m_n_); z_grid_desc_m_n_);
// Print(); // Print();
auto a_grid_desc_k0_m_k1 =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
num_gemm0_m_block_outer_loop_ = a_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
m_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]); m_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]); n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
} }
...@@ -241,7 +241,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -241,7 +241,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
KGridDesc_N_K k_grid_desc_n_k_; KGridDesc_N_K k_grid_desc_n_k_;
...@@ -257,6 +256,8 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -257,6 +256,8 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
// For robust IsSupportedArgument() check // For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_; std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
index_t num_gemm0_m_block_outer_loop_;
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
...@@ -288,7 +289,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -288,7 +289,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
const auto kernel = kernel_batched_dropout< const auto kernel = kernel_batched_dropout<
GridwiseDropout, GridwiseDropout,
ZDataType, ZDataType,
DeviceOp::AGridDesc_AK0_M_AK1,
typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename GridwiseDropout::DefaultBlock2CTileMap, typename GridwiseDropout::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch>; ComputeBasePtrOfStridedBatch>;
...@@ -299,9 +299,9 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -299,9 +299,9 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_z_grid_, arg.p_z_grid_,
arg.a_grid_desc_ak0_m_ak1_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.num_gemm0_m_block_outer_loop_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.seed_, arg.seed_,
......
...@@ -19,7 +19,6 @@ namespace ck { ...@@ -19,7 +19,6 @@ namespace ck {
template <typename ZDataType, template <typename ZDataType,
typename GemmDataType, typename GemmDataType,
typename FloatGemmAcc, typename FloatGemmAcc,
typename QGridDesc_K0_M_K1,
typename KGridDesc_N_K, typename KGridDesc_N_K,
typename ZGridDesc_M_N, typename ZGridDesc_M_N,
index_t BlockSize, index_t BlockSize,
...@@ -203,11 +202,11 @@ struct GridwiseBatchedDropout ...@@ -203,11 +202,11 @@ struct GridwiseBatchedDropout
template <typename Block2CTileMap> template <typename Block2CTileMap>
__device__ static void Run(ZDataType* __restrict__ p_z_grid, __device__ static void Run(ZDataType* __restrict__ p_z_grid,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
ck::philox& ph, ck::philox& ph,
const index_t num_gemm0_m_block_outer_loop,
const index_t z_random_matrix_offset, const index_t z_random_matrix_offset,
const index_t raw_n_padded) const index_t raw_n_padded)
{ {
...@@ -219,8 +218,6 @@ struct GridwiseBatchedDropout ...@@ -219,8 +218,6 @@ struct GridwiseBatchedDropout
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
const index_t num_gemm0_m_block_outer_loop = q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
// S: blockwise gemm // S: blockwise gemm
auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment