Commit 6d2e3152 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Simplify DeviceBatchedDropout MakeArgument parameters

parent 68903d4d
...@@ -233,10 +233,6 @@ int run(int argc, char* argv[]) ...@@ -233,10 +233,6 @@ int run(int argc, char* argv[])
auto dropout_arg = auto dropout_arg =
dropout_op.MakeArgument(static_cast<ZDataType*>(z_device_buf_2.GetDeviceBuffer()), dropout_op.MakeArgument(static_cast<ZDataType*>(z_device_buf_2.GetDeviceBuffer()),
a_gs_ms_ks_lengths,
{0, 0, 0, 0},
b0_gs_ns_ks_lengths,
{0, 0, 0, 0},
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
{seed, offset}); {seed, offset});
......
...@@ -182,8 +182,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -182,8 +182,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(ZDataType* p_z_grid, Argument(ZDataType* p_z_grid,
const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_m_k_strides,
const std::vector<index_t>& b_gs_n_k_lengths, const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_n_k_strides, const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_m_n_lengths, const std::vector<index_t>& z_gs_m_n_lengths,
...@@ -196,9 +194,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -196,9 +194,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_m_n_lengths, z_gs_m_n_strides)}, Transform::MakeCGridDescriptor_G_M_N(z_gs_m_n_lengths, z_gs_m_n_strides)},
block_2_ctile_map_{GridwiseDropout::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseDropout::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_m_k_lengths[NumDimG],
b_gs_n_k_lengths[NumDimG],
b_gs_n_k_lengths[NumDimG + 1]},
batch_count_{z_grid_desc_g_m_n_.GetLength(I0)} batch_count_{z_grid_desc_g_m_n_.GetLength(I0)}
{ {
...@@ -210,15 +205,22 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -210,15 +205,22 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseDropout::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3( GridwiseDropout::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n_); z_grid_desc_m_n_);
// Print();
auto m_raw = z_gs_m_n_lengths[NumDimG];
auto n_raw = z_gs_m_n_lengths[NumDimG + 1];
m_raw_padded_ = GridwiseDropout::GetPaddedSize(m_raw);
n_raw_padded_ = GridwiseDropout::GetPaddedSize(n_raw);
std::vector<index_t> a_gs_m_k_strides(NumDimG + 2, 0);
std::vector<index_t> a_gs_m_k_lengths = z_gs_m_n_lengths;
a_gs_m_k_lengths[NumDimG + 1] = 1;
auto a_grid_desc_k0_m_k1 = auto a_grid_desc_k0_m_k1 =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_m_k_lengths, a_gs_m_k_strides); DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_m_k_lengths, a_gs_m_k_strides);
num_gemm0_m_block_outer_loop_ = a_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock; num_gemm0_m_block_outer_loop_ = a_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
m_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
} }
// pointers // pointers
...@@ -237,9 +239,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -237,9 +239,6 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
// block-to-c-tile map // block-to-c-tile map
typename GridwiseDropout::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseDropout::DefaultBlock2CTileMap block_2_ctile_map_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
index_t num_gemm0_m_block_outer_loop_; index_t num_gemm0_m_block_outer_loop_;
index_t batch_count_; index_t batch_count_;
...@@ -332,41 +331,35 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator ...@@ -332,41 +331,35 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
} }
static auto MakeArgument(ZDataType* p_z, static auto MakeArgument(ZDataType* p_z,
const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_m_k_strides,
const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_m_n_lengths, const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_m_n_strides, const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_z, std::vector<index_t> b_gs_n_k_strides(NumDimG + 2, 0);
a_gs_m_k_lengths, std::vector<index_t> b_gs_n_k_lengths = z_gs_m_n_lengths;
a_gs_m_k_strides,
b_gs_n_k_lengths, b_gs_n_k_lengths[NumDimG] = z_gs_m_n_lengths[NumDimG + 1];
b_gs_n_k_strides, b_gs_n_k_lengths[NumDimG + 1] = 1;
z_gs_m_n_lengths,
z_gs_m_n_strides, return Argument{
seeds}; p_z, b_gs_n_k_lengths, b_gs_n_k_strides, z_gs_m_n_lengths, z_gs_m_n_strides, seeds};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_z, MakeArgumentPointer(void* p_z,
const std::vector<index_t>& a_gs_m_k_lengths,
const std::vector<index_t>& a_gs_m_k_strides,
const std::vector<index_t>& b_gs_n_k_lengths,
const std::vector<index_t>& b_gs_n_k_strides,
const std::vector<index_t>& z_gs_m_n_lengths, const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_m_n_strides, const std::vector<index_t>& z_gs_m_n_strides,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
std::vector<index_t> b_gs_n_k_strides(NumDimG + 2, 0);
std::vector<index_t> b_gs_n_k_lengths = z_gs_m_n_lengths;
b_gs_n_k_lengths[NumDimG] = z_gs_m_n_lengths[NumDimG + 1];
b_gs_n_k_lengths[NumDimG + 1] = 1;
return std::make_unique<Argument>(static_cast<ZDataType*>(p_z), return std::make_unique<Argument>(static_cast<ZDataType*>(p_z),
a_gs_m_k_lengths,
a_gs_m_k_strides,
b_gs_n_k_lengths, b_gs_n_k_lengths,
b_gs_n_k_strides, b_gs_n_k_strides,
z_gs_m_n_lengths, z_gs_m_n_lengths,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment