Commit 68903d4d authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Simplification by removing some templates in DeviceBatchedDropout

parent 3c1d5412
...@@ -301,31 +301,25 @@ using DeviceGemmInstance = ...@@ -301,31 +301,25 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#endif #endif
using DeviceDropoutInstance = using DeviceDropoutInstance = ck::tensor_operation::device::DeviceBatchedDropout<NumDimG,
ck::tensor_operation::device::DeviceBatchedDropout<NumDimG, GemmDataType,
NumDimM, ZDataType,
NumDimN, GemmDataType,
NumDimK, GemmSpec,
NumDimO, TensorSpecA,
GemmDataType, TensorSpecB0,
ZDataType, TensorSpecB1,
GemmDataType, TensorSpecC,
GemmSpec, 256, // BlockSize
TensorSpecA, 64, // MPerBlock
TensorSpecB0, 128, // NPerBlock
TensorSpecB1, 32, // KPerBlock
TensorSpecC, 8, // AK1
256, // BlockSize 8, // BK1
64, // MPerBlock 32, // MPerXDL
128, // NPerBlock 32, // NPerXDL
32, // KPerBlock 2, // MXdlPerWave
128, // Gemm1NPerBlock 1>; // NXdlPerWave
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
1>; // NXdlPerWave
#include "run_batched_multihead_attention_bias_forward_zcheck.inc" #include "run_batched_multihead_attention_bias_forward_zcheck.inc"
......
...@@ -85,10 +85,6 @@ __global__ void ...@@ -85,10 +85,6 @@ __global__ void
// ^^^^^^ (Acc0) // ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1) // ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG, template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename GemmAccDataType, typename GemmAccDataType,
...@@ -101,7 +97,6 @@ template <index_t NumDimG, ...@@ -101,7 +97,6 @@ template <index_t NumDimG,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, // Gemm0NPerBlock index_t NPerBlock, // Gemm0NPerBlock
index_t KPerBlock, // Gemm0KPerBlock index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t AK1, index_t AK1,
index_t BK1, index_t BK1,
index_t MPerXDL, index_t MPerXDL,
...@@ -110,17 +105,18 @@ template <index_t NumDimG, ...@@ -110,17 +105,18 @@ template <index_t NumDimG,
index_t NXdlPerWave> index_t NXdlPerWave>
struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0, "Number of dimension must be greater than 0");
"Number of dimension must be greater than 0");
using DeviceOp = DeviceBatchedDropout; using DeviceOp = DeviceBatchedDropout;
static constexpr index_t Gemm1NPerBlock = 128;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm< using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>, Sequence<NumDimG, 1, 1, 1, 1>, // NumDimM, NumDimN, NumDimK, NumDimO
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>, Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec, GemmSpec,
ASpec, ASpec,
...@@ -128,31 +124,19 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -128,31 +124,19 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
B1Spec, B1Spec,
CSpec>; CSpec>;
/*
Descriptors for inputs:
Q, K, V, Y, dY, per-row softmax stats
Descriptors for outputs:
dQ, dK, dV
*/
// Q in Gemm A position // Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides) const std::vector<index_t>& a_gs_m_k_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides), Transform::MakeAGridDescriptor_M_K(a_gs_m_k_lengths, a_gs_m_k_strides), Number<AK1>{});
Number<AK1>{});
} }
// Z in Gemm0 C position // Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides) const std::vector<index_t>& z_gs_m_n_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); return Transform::MakeCGridDescriptor_M_N(z_gs_m_n_lengths, z_gs_m_n_strides);
} }
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
...@@ -198,23 +182,23 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -198,23 +182,23 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(ZDataType* p_z_grid, Argument(ZDataType* p_z_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_m_k_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: p_z_grid_{p_z_grid}, : p_z_grid_{p_z_grid},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_m_n_lengths, z_gs_m_n_strides)},
k_grid_desc_n_k_{ k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, Transform::MakeB0GridDescriptor_N_K(b_gs_n_k_lengths, b_gs_n_k_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeCGridDescriptor_G_M_N(z_gs_m_n_lengths, z_gs_m_n_strides)},
block_2_ctile_map_{GridwiseDropout::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseDropout::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], raw_lengths_mz_nz_kz_gemm1nz_{a_gs_m_k_lengths[NumDimG],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], b_gs_n_k_lengths[NumDimG],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1]}, b_gs_n_k_lengths[NumDimG + 1]},
batch_count_{z_grid_desc_g_m_n_.GetLength(I0)} batch_count_{z_grid_desc_g_m_n_.GetLength(I0)}
{ {
...@@ -229,7 +213,7 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -229,7 +213,7 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
// Print(); // Print();
auto a_grid_desc_k0_m_k1 = auto a_grid_desc_k0_m_k1 =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_m_k_lengths, a_gs_m_k_strides);
num_gemm0_m_block_outer_loop_ = a_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock; num_gemm0_m_block_outer_loop_ = a_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
...@@ -348,21 +332,21 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -348,21 +332,21 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
} }
static auto MakeArgument(ZDataType* p_z, static auto MakeArgument(ZDataType* p_z,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_m_k_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_z, return Argument{p_z,
a_gs_ms_ks_lengths, a_gs_m_k_lengths,
a_gs_ms_ks_strides, a_gs_m_k_strides,
b_gs_ns_ks_lengths, b_gs_n_k_lengths,
b_gs_ns_ks_strides, b_gs_n_k_strides,
z_gs_ms_ns_lengths, z_gs_m_n_lengths,
z_gs_ms_ns_strides, z_gs_m_n_strides,
seeds}; seeds};
} }
...@@ -372,21 +356,21 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -372,21 +356,21 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
// FIXME: constness // FIXME: constness
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_z, MakeArgumentPointer(void* p_z,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_m_k_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<ZDataType*>(p_z), return std::make_unique<Argument>(static_cast<ZDataType*>(p_z),
a_gs_ms_ks_lengths, a_gs_m_k_lengths,
a_gs_ms_ks_strides, a_gs_m_k_strides,
b_gs_ns_ks_lengths, b_gs_n_k_lengths,
b_gs_ns_ks_strides, b_gs_n_k_strides,
z_gs_ms_ns_lengths, z_gs_m_n_lengths,
z_gs_ms_ns_strides, z_gs_m_n_strides,
seeds); seeds);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment