Commit edbb3439 authored by danyao12's avatar danyao12
Browse files

Merge branch 'mha-train-develop' into mha-train-develop-bwdopt-bias

parents 31706d42 1f04cd2b
......@@ -301,31 +301,25 @@ using DeviceGemmInstance =
Deterministic>;
#endif
using DeviceDropoutInstance =
ck::tensor_operation::device::DeviceBatchedDropout<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
GemmDataType,
ZDataType,
GemmDataType,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
256, // BlockSize
64, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
1>; // NXdlPerWave
using DeviceDropoutInstance = ck::tensor_operation::device::DeviceBatchedDropout<NumDimG,
GemmDataType,
ZDataType,
GemmDataType,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
256, // BlockSize
64, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
1>; // NXdlPerWave
#include "run_batched_multihead_attention_bias_forward_zcheck.inc"
......
......@@ -233,10 +233,6 @@ int run(int argc, char* argv[])
auto dropout_arg =
dropout_op.MakeArgument(static_cast<ZDataType*>(z_device_buf_2.GetDeviceBuffer()),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
{seed, offset});
......
......@@ -27,7 +27,6 @@ namespace device {
template <typename GridwiseDropout_,
typename ZDataType,
typename AGridDesc_AK0_M_AK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch>
......@@ -36,10 +35,10 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_dropout(ZDataType* __restrict__ p_z_grid,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const Block2CTileMap block_2_ctile_map,
const index_t num_gemm0_m_block_outer_loop,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const unsigned long long seed,
......@@ -62,10 +61,10 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
GridwiseDropout_::Run(z_matrix_ptr,
a_grid_desc_ak0_m_ak1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
block_2_ctile_map,
ph,
num_gemm0_m_block_outer_loop,
z_random_matrix_offset,
raw_n_padded);
#else
......@@ -86,10 +85,6 @@ __global__ void
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename GemmDataType,
typename ZDataType,
typename GemmAccDataType,
......@@ -102,7 +97,6 @@ template <index_t NumDimG,
index_t MPerBlock,
index_t NPerBlock, // Gemm0NPerBlock
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
......@@ -111,17 +105,18 @@ template <index_t NumDimG,
index_t NXdlPerWave>
struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static_assert(NumDimG > 0, "Number of dimension must be greater than 0");
using DeviceOp = DeviceBatchedDropout;
static constexpr index_t Gemm1NPerBlock = 128;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<NumDimG, 1, 1, 1, 1>, // NumDimM, NumDimN, NumDimK, NumDimO
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
ASpec,
......@@ -129,35 +124,22 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
B1Spec,
CSpec>;
/*
Descriptors for inputs:
Q, K, V, Y, dY, per-row softmax stats
Descriptors for outputs:
dQ, dK, dV
*/
// Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides)
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_m_k_strides)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{});
Transform::MakeAGridDescriptor_M_K(a_gs_m_k_lengths, a_gs_m_k_strides), Number<AK1>{});
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_m_n_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeCGridDescriptor_M_N(z_gs_m_n_lengths, z_gs_m_n_strides);
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
......@@ -182,7 +164,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
using GridwiseDropout = GridwiseBatchedDropout<ZDataType,
GemmDataType,
GemmAccDataType,
AGridDesc_AK0_M_AK1,
KGridDesc_N_K,
ZGridDesc_M_N,
BlockSize,
......@@ -201,25 +182,18 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
struct Argument : public BaseArgument
{
Argument(ZDataType* p_z_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds)
: p_z_grid_{p_z_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_m_n_lengths, z_gs_m_n_strides)},
k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
Transform::MakeB0GridDescriptor_N_K(b_gs_n_k_lengths, b_gs_n_k_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
Transform::MakeCGridDescriptor_G_M_N(z_gs_m_n_lengths, z_gs_m_n_strides)},
block_2_ctile_map_{GridwiseDropout::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1]},
batch_count_{z_grid_desc_g_m_n_.GetLength(I0)}
{
......@@ -231,17 +205,28 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseDropout::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n_);
// Print();
m_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
auto m_raw = z_gs_m_n_lengths[NumDimG];
auto n_raw = z_gs_m_n_lengths[NumDimG + 1];
m_raw_padded_ = GridwiseDropout::GetPaddedSize(m_raw);
n_raw_padded_ = GridwiseDropout::GetPaddedSize(n_raw);
std::vector<index_t> a_gs_m_k_strides(NumDimG + 2, 0);
std::vector<index_t> a_gs_m_k_lengths = z_gs_m_n_lengths;
a_gs_m_k_lengths[NumDimG + 1] = 1;
auto a_grid_desc_k0_m_k1 =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_m_k_lengths, a_gs_m_k_strides);
num_gemm0_m_block_outer_loop_ = a_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
}
// pointers
ZDataType* p_z_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
ZGridDesc_M_N z_grid_desc_m_n_;
KGridDesc_N_K k_grid_desc_n_k_;
......@@ -254,8 +239,7 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
// block-to-c-tile map
typename GridwiseDropout::DefaultBlock2CTileMap block_2_ctile_map_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
index_t num_gemm0_m_block_outer_loop_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
......@@ -288,7 +272,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
const auto kernel = kernel_batched_dropout<
GridwiseDropout,
ZDataType,
DeviceOp::AGridDesc_AK0_M_AK1,
typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename GridwiseDropout::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch>;
......@@ -299,9 +282,9 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
dim3(BlockSize),
0,
arg.p_z_grid_,
arg.a_grid_desc_ak0_m_ak1_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.block_2_ctile_map_,
arg.num_gemm0_m_block_outer_loop_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_,
arg.seed_,
......@@ -348,45 +331,39 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
}
static auto MakeArgument(ZDataType* p_z,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_z,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
seeds};
std::vector<index_t> b_gs_n_k_strides(NumDimG + 2, 0);
std::vector<index_t> b_gs_n_k_lengths = z_gs_m_n_lengths;
b_gs_n_k_lengths[NumDimG] = z_gs_m_n_lengths[NumDimG + 1];
b_gs_n_k_lengths[NumDimG + 1] = 1;
return Argument{
p_z, b_gs_n_k_lengths, b_gs_n_k_strides, z_gs_m_n_lengths, z_gs_m_n_strides, seeds};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_z,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
std::vector<index_t> b_gs_n_k_strides(NumDimG + 2, 0);
std::vector<index_t> b_gs_n_k_lengths = z_gs_m_n_lengths;
b_gs_n_k_lengths[NumDimG] = z_gs_m_n_lengths[NumDimG + 1];
b_gs_n_k_lengths[NumDimG + 1] = 1;
return std::make_unique<Argument>(static_cast<ZDataType*>(p_z),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b_gs_n_k_lengths,
b_gs_n_k_strides,
z_gs_m_n_lengths,
z_gs_m_n_strides,
seeds);
}
......
......@@ -19,7 +19,6 @@ namespace ck {
template <typename ZDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename QGridDesc_K0_M_K1,
typename KGridDesc_N_K,
typename ZGridDesc_M_N,
index_t BlockSize,
......@@ -203,11 +202,11 @@ struct GridwiseBatchedDropout
template <typename Block2CTileMap>
__device__ static void Run(ZDataType* __restrict__ p_z_grid,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const Block2CTileMap& block_2_ctile_map,
ck::philox& ph,
const index_t num_gemm0_m_block_outer_loop,
const index_t z_random_matrix_offset,
const index_t raw_n_padded)
{
......@@ -219,8 +218,6 @@ struct GridwiseBatchedDropout
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
const index_t num_gemm0_m_block_outer_loop = q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
// S: blockwise gemm
auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment