Commit 6d2e3152 authored by Qianfeng Zhang
Browse files

Simplify DeviceBatchedDropout MakeArgument parameters

parent 68903d4d
......@@ -233,10 +233,6 @@ int run(int argc, char* argv[])
auto dropout_arg =
dropout_op.MakeArgument(static_cast<ZDataType*>(z_device_buf_2.GetDeviceBuffer()),
a_gs_ms_ks_lengths,
{0, 0, 0, 0},
b0_gs_ns_ks_lengths,
{0, 0, 0, 0},
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
{seed, offset});
......
......@@ -182,8 +182,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
struct Argument : public BaseArgument
{
Argument(ZDataType* p_z_grid,
const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_m_k_strides,
const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_m_n_lengths,
......@@ -196,9 +194,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_m_n_lengths, z_gs_m_n_strides)},
block_2_ctile_map_{GridwiseDropout::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_m_k_lengths[NumDimG],
b_gs_n_k_lengths[NumDimG],
b_gs_n_k_lengths[NumDimG + 1]},
batch_count_{z_grid_desc_g_m_n_.GetLength(I0)}
{
......@@ -210,15 +205,22 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseDropout::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n_);
// Print();
auto m_raw = z_gs_m_n_lengths[NumDimG];
auto n_raw = z_gs_m_n_lengths[NumDimG + 1];
m_raw_padded_ = GridwiseDropout::GetPaddedSize(m_raw);
n_raw_padded_ = GridwiseDropout::GetPaddedSize(n_raw);
std::vector<index_t> a_gs_m_k_strides(NumDimG + 2, 0);
std::vector<index_t> a_gs_m_k_lengths = z_gs_m_n_lengths;
a_gs_m_k_lengths[NumDimG + 1] = 1;
auto a_grid_desc_k0_m_k1 =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_m_k_lengths, a_gs_m_k_strides);
num_gemm0_m_block_outer_loop_ = a_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
m_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
}
// pointers
......@@ -237,9 +239,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
// block-to-c-tile map
typename GridwiseDropout::DefaultBlock2CTileMap block_2_ctile_map_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
index_t num_gemm0_m_block_outer_loop_;
index_t batch_count_;
......@@ -332,41 +331,35 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
}
// Builds an Argument for the dropout op from the Z tensor pointer, the Z
// lengths/strides and the seed tuple.
//
// NOTE(review): this text comes from a diff view rendered without +/- markers, so
// the parameter list and body below appear to interleave the pre-change and
// post-change versions of MakeArgument (note the duplicate b_gs_n_k_* names and the
// two return statements). Presumably the a_gs_m_k_* / b_gs_n_k_* parameters and the
// first return are the removed side — confirm against the actual file.
static auto MakeArgument(ZDataType* p_z,
const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_m_k_strides,
const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds)
{
// Forwards every caller-supplied length/stride vector unchanged.
return Argument{p_z,
a_gs_m_k_lengths,
a_gs_m_k_strides,
b_gs_n_k_lengths,
b_gs_n_k_strides,
z_gs_m_n_lengths,
z_gs_m_n_strides,
seeds};
// Derives the B-side description locally instead of taking it from the caller:
// strides are NumDimG + 2 zeros.
std::vector<index_t> b_gs_n_k_strides(NumDimG + 2, 0);
// Lengths start as a copy of the Z lengths (so the leading NumDimG entries match Z).
std::vector<index_t> b_gs_n_k_lengths = z_gs_m_n_lengths;
// Entry NumDimG takes Z's entry NumDimG + 1, and entry NumDimG + 1 is fixed to 1.
b_gs_n_k_lengths[NumDimG] = z_gs_m_n_lengths[NumDimG + 1];
b_gs_n_k_lengths[NumDimG + 1] = 1;
// Only the derived B vectors and the Z vectors are passed on; no A vectors.
return Argument{
p_z, b_gs_n_k_lengths, b_gs_n_k_strides, z_gs_m_n_lengths, z_gs_m_n_strides, seeds};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_z,
const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_m_k_strides,
const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
std::vector<index_t> b_gs_n_k_strides(NumDimG + 2, 0);
std::vector<index_t> b_gs_n_k_lengths = z_gs_m_n_lengths;
b_gs_n_k_lengths[NumDimG] = z_gs_m_n_lengths[NumDimG + 1];
b_gs_n_k_lengths[NumDimG + 1] = 1;
return std::make_unique<Argument>(static_cast<ZDataType*>(p_z),
a_gs_m_k_lengths,
a_gs_m_k_strides,
b_gs_n_k_lengths,
b_gs_n_k_strides,
z_gs_m_n_lengths,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment