Commit 9241ff62 authored by guangzlu

modified method to set offset in philox

parent 11eed39f
......@@ -178,7 +178,8 @@ int run(int argc, char* argv[])
acc0_element_op,
b1_element_op,
c_element_op,
0); // dropout ratio
0, // dropout ratio
{0, 64}); // dropout random seed and offset
if(!gemm.IsSupportedArgument(argument))
{
......
......@@ -218,7 +218,8 @@ int run(int argc, char* argv[])
acc0_element_op,
b1_element_op,
c_element_op,
0); // dropout ratio
0, // dropout ratio
{0, 448}); // dropout random seed and offset
// specify workspace for problem_desc
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
......
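Both example call sites above switch from relying on the operator's previously defaulted seed to passing an explicit {seed, offset} pair next to the dropout ratio. A minimal, standalone sketch of how that pair is built and later unpacked; the values are taken from the examples above, everything else is illustrative:

#include <tuple>
#include <iostream>

int main()
{
    // The examples pass the dropout randomness as a
    // std::tuple<unsigned long long, unsigned long long> holding {seed, offset},
    // e.g. {0, 64} in the batched example and {0, 448} in the grouped one.
    std::tuple<unsigned long long, unsigned long long> seeds{0, 64};

    // The device ops unpack the pair with std::get, as the Argument constructors
    // further down in this commit do.
    const unsigned long long seed   = std::get<0>(seeds);
    const unsigned long long offset = std::get<1>(seeds);

    std::cout << "seed = " << seed << ", offset = " << offset << "\n";
    return 0;
}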
......@@ -17,10 +17,7 @@ struct BlockwiseDropout
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
template <typename CThreadBuffer>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf,
ck::philox ph,
const int repeat_index,
const int total_repeats)
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
{
auto execute_dropout = [&](bool keep, DataType val) {
......@@ -28,15 +25,13 @@ struct BlockwiseDropout
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8;
int tid = get_thread_global_1d_id();
unsigned long long uni_subsequence =
tid * total_repeats * philox_calls + repeat_index * philox_calls;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8), (uni_subsequence + i));
ph.get_random_8x16((tmp + i * 8));
}
block_sync_lds();
......
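The simplified ApplyDropout keeps the same call count per register tile: each ph.get_random_8x16 fills 8 ushort values, so a tile of MRepeat x KRepeat elements needs (MRepeat * KRepeat) / 8 Philox calls. A standalone sketch of that arithmetic, with hypothetical repeat counts (the real values come from ThreadSliceDesc_M_K):

#include <cstdio>

int main()
{
    // Hypothetical tile shape for illustration only.
    const int MRepeat = 4;
    const int KRepeat = 16;

    const int tmp_size     = MRepeat * KRepeat; // ushort values to generate
    const int philox_calls = tmp_size / 8;      // get_random_8x16 yields 8 per call

    std::printf("%d random 16-bit values -> %d philox calls per ApplyDropout\n",
                tmp_size, philox_calls);
    return 0;
}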
......@@ -5,6 +5,7 @@
#include <iostream>
#include <vector>
#include <tuple>
#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
......@@ -117,7 +118,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0) = 0;
std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -129,7 +129,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0) = 0;
std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -69,8 +69,9 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits,
GemmAccDataType p_dropout_rescale,
const unsigned long long seed)
const GemmAccDataType p_dropout_rescale,
const unsigned long long seed,
const unsigned long long offset)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -89,8 +90,8 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const index_t block_id = get_block_1d_id();
ck::philox ph(seed, 0, block_id * 4);
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset,
......@@ -478,7 +479,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
unsigned long long seed)
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid},
......@@ -527,8 +528,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_grid_desc_g_n_k_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
seed_(seed)
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
......@@ -554,6 +554,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_;
p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
std::cout << "seed_" << seed_ << std::endl;
std::cout << "offset_" << offset_ << std::endl;
}
void Print() const
......@@ -619,6 +625,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
ushort p_dropout_in_16bits_;
GemmAccDataType p_dropout_rescale_;
unsigned long long seed_;
unsigned long long offset_;
bool is_dropout_;
};
......@@ -692,7 +699,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_,
arg.p_dropout_rescale_,
arg.seed_);
arg.seed_,
arg.offset_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......@@ -846,7 +854,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0)
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_a,
p_b,
......@@ -874,7 +882,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b1_element_op,
c_element_op,
p_dropout,
seed};
seeds};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -910,7 +918,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0) override
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -938,7 +946,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b1_element_op,
c_element_op,
p_dropout,
seed);
seeds);
}
// polymorphic
......
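For the batched path, the kernel no longer derives its Philox stream from the block id (the old ck::philox ph(seed, 0, block_id * 4)); each thread now uses its global thread id as the subsequence, and the counter offset comes from the host. A hedged device-side sketch that reuses only the ck::philox calls visible in this diff; the kernel name, output layout, and include path are assumptions, not part of the commit:

// Assumed include; the Philox utility lives under the ck/ headers in the repository.
// #include "ck/utility/philox_rand.hpp"

__global__ void dropout_bits_kernel(ushort* p_out,
                                    const unsigned long long seed,
                                    const unsigned long long offset)
{
    const index_t tid = get_thread_global_1d_id();

    // One subsequence per thread, positioned by the host-supplied offset,
    // mirroring the new ck::philox construction in the train kernels above.
    ck::philox ph(seed, tid, offset);

    ushort tmp[8];
    ph.get_random_8x16(tmp); // fills 8 random 16-bit values

    for(int i = 0; i < 8; ++i)
        p_out[tid * 8 + i] = tmp[i];
}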
......@@ -46,15 +46,17 @@ __global__ void
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const ushort p_dropout_in_16bits,
GemmAccDataType p_dropout_rescale,
const unsigned long long seed)
const GemmAccDataType p_dropout_rescale,
const unsigned long long seed,
const unsigned long long offset)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, 0, block_id * 4);
ck::philox ph(seed, global_thread_id, offset);
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args));
......@@ -519,13 +521,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
unsigned long long seed)
std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
seed_(seed)
c_element_op_{c_element_op}
{
// TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size();
......@@ -647,6 +648,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_;
p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
}
std::vector<GroupKernelArg> group_kernel_args_;
......@@ -664,6 +668,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float p_dropout_;
ushort p_dropout_in_16bits_;
unsigned long long seed_;
unsigned long long offset_;
GemmAccDataType p_dropout_rescale_;
bool is_dropout_;
};
......@@ -726,7 +731,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg.c_element_op_,
arg.p_dropout_in_16bits_,
arg.p_dropout_rescale_,
arg.seed_);
arg.seed_,
arg.offset_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......@@ -895,7 +901,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0)
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_a_vec,
p_b_vec,
......@@ -911,7 +917,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b1_element_op,
c_element_op,
p_dropout,
seed};
seeds};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -932,7 +938,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0) override
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(p_a_vec,
p_b_vec,
......@@ -948,7 +954,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b1_element_op,
c_element_op,
p_dropout,
seed);
seeds);
}
// polymorphic
......
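The grouped path gets the same treatment: seed and offset arrive as plain kernel arguments and every thread constructs ck::philox(seed, global_thread_id, offset). What the commit leaves to the caller is choosing the offset. Assuming ck::philox follows the usual counter-based convention, where the offset positions each thread's counter inside its subsequence, a host-side sketch of how successive dropout launches could keep their streams disjoint; this helper is hypothetical and not part of the commit:

#include <tuple>

// Hypothetical helper: advance the Philox offset by at least the number of
// get_random_8x16 calls one thread makes per launch, so the next launch does not
// reuse the same counter values. The exact increment depends on the kernel's tile
// shape and loop count, which this sketch does not know.
std::tuple<unsigned long long, unsigned long long>
next_dropout_seeds(const unsigned long long seed,
                   const unsigned long long current_offset,
                   const unsigned long long philox_calls_per_thread)
{
    return {seed, current_offset + philox_calls_per_thread};
}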
......@@ -781,6 +781,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
do
{
auto n_block_data_idx_on_grid =
......@@ -875,8 +876,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
if constexpr(IsDropout) // dropout
{
blockwise_dropout.ApplyDropout(
acc_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);
blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
}
// TODO: may convert to log domain
......
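Before this commit, the gridwise loop passed gemm1_k_block_outer_index and num_gemm1_k_block_outer_loop down to ApplyDropout, which combined them with the thread id into a unique subsequence per call. A standalone reproduction of that removed arithmetic, with hypothetical values, shows how the old scheme spaced the streams:

#include <cstdio>

int main()
{
    // Hypothetical values; in the removed code these came from
    // get_thread_global_1d_id(), num_gemm1_k_block_outer_loop and
    // (MRepeat * KRepeat) / 8 respectively.
    const unsigned long long tid           = 37;
    const unsigned long long total_repeats = 4;
    const unsigned long long philox_calls  = 8;

    for(unsigned long long repeat_index = 0; repeat_index < total_repeats; ++repeat_index)
    {
        const unsigned long long uni_subsequence =
            tid * total_repeats * philox_calls + repeat_index * philox_calls;
        std::printf("repeat %llu -> subsequences [%llu, %llu)\n",
                    repeat_index, uni_subsequence, uni_subsequence + philox_calls);
    }
    return 0;
}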