"...composable_kernel.git" did not exist on "f99f419dfd07cbad25d8068cf7fed84ab501f0b3"
Commit 9241ff62 authored by guangzlu's avatar guangzlu
Browse files

modified method to set offset in philox

parent 11eed39f
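In short: the Philox offset is now supplied by the caller instead of being derived inside the kernels. Previously each kernel built a per-block generator and ApplyDropout recomputed a subsequence from (tid, repeat_index, total_repeats) on every call; now the host passes a {seed, offset} pair and each thread builds its own generator keyed on its global thread id. A minimal before/after sketch, using only the calls visible in this diff:

    // Before: one generator per block; per-call subsequence bookkeeping
    // lived inside BlockwiseDropout::ApplyDropout.
    const index_t block_id = get_block_1d_id();
    ck::philox ph(seed, 0, block_id * 4);

    // After: one generator per thread; the host chooses the offset, so
    // the kernel needs no repeat-index plumbing.
    const index_t global_thread_id = get_thread_global_1d_id();
    ck::philox ph(seed, global_thread_id, offset);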
@@ -178,7 +178,8 @@ int run(int argc, char* argv[])
         acc0_element_op,
         b1_element_op,
         c_element_op,
-        0); // dropout ratio
+        0,        // dropout ratio
+        {0, 64}); // dropout random seed and offset

     if(!gemm.IsSupportedArgument(argument))
     {
...
@@ -218,7 +218,8 @@ int run(int argc, char* argv[])
         acc0_element_op,
         b1_element_op,
         c_element_op,
-        0); // dropout ratio
+        0,         // dropout ratio
+        {0, 448}); // dropout random seed and offset

     // specify workspace for problem_desc
     DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...
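The example updates above show the new host-side contract: the trailing argument is now a {seed, offset} pair rather than a bare seed. One plausible usage pattern (a sketch only; the helper below and its increment policy are illustrative, not part of the library) is to keep a single seed and advance the offset between launches so successive kernels draw from disjoint counter ranges:

    #include <tuple>

    // Hypothetical host-side bookkeeping: reuse one seed and bump the
    // offset by at least the number of Philox values a launch consumes
    // per thread (the examples above pass offsets of 64 and 448).
    struct DropoutRngState
    {
        unsigned long long seed   = 0;
        unsigned long long offset = 0;

        std::tuple<unsigned long long, unsigned long long>
        next(unsigned long long consumed)
        {
            auto seeds = std::make_tuple(seed, offset);
            offset += consumed;
            return seeds;
        }
    };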
@@ -17,10 +17,7 @@ struct BlockwiseDropout
     static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);

     template <typename CThreadBuffer>
-    __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf,
-                                          ck::philox ph,
-                                          const int repeat_index,
-                                          const int total_repeats)
+    __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
     {
         auto execute_dropout = [&](bool keep, DataType val) {
@@ -28,15 +25,13 @@ struct BlockwiseDropout
         };

         constexpr int tmp_size = MRepeat * KRepeat;

         int philox_calls = tmp_size / 8;
-        int tid          = get_thread_global_1d_id();
-        unsigned long long uni_subsequence =
-            tid * total_repeats * philox_calls + repeat_index * philox_calls;

         ushort tmp[tmp_size];

         for(int i = 0; i < philox_calls; i++)
         {
-            ph.get_random_8x16((tmp + i * 8), (uni_subsequence + i));
+            ph.get_random_8x16((tmp + i * 8));
         }

         block_sync_lds();
...
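With the explicit subsequence argument gone, ApplyDropout reduces to a plain fill loop: each get_random_8x16 call produces eight 16-bit samples, so MRepeat * KRepeat / 8 calls cover the whole per-thread tile (the division by 8 presumes tmp_size is a multiple of 8). Annotated, the new loop core reads:

    // Simplified view of the new ApplyDropout core: the generator hands
    // out successive 8x16-bit packets, so no per-call subsequence is
    // computed here any more.
    constexpr int tmp_size     = MRepeat * KRepeat; // samples per thread tile
    int           philox_calls = tmp_size / 8;      // 8 ushorts per call
    ushort        tmp[tmp_size];
    for(int i = 0; i < philox_calls; i++)
    {
        ph.get_random_8x16(tmp + i * 8); // fill the next 8 samples
    }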
@@ -5,6 +5,7 @@
 #include <iostream>
 #include <vector>
+#include <tuple>

 #include "device_base.hpp"
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
@@ -117,7 +118,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op,
                        float p_dropout,
-                       const unsigned long long seed = 0) = 0;
+                       std::tuple<unsigned long long, unsigned long long> seeds) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
...
@@ -129,7 +129,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op,
                        float p_dropout,
-                       const unsigned long long seed = 0) = 0;
+                       std::tuple<unsigned long long, unsigned long long> seeds) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
...
@@ -69,8 +69,9 @@ __global__ void
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
             const C0MatrixMask c0_matrix_mask,
             const ushort p_dropout_in_16bits,
-            GemmAccDataType p_dropout_rescale,
-            const unsigned long long seed)
+            const GemmAccDataType p_dropout_rescale,
+            const unsigned long long seed,
+            const unsigned long long offset)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -89,8 +90,8 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));

-    const index_t block_id = get_block_1d_id();
-    ck::philox ph(seed, 0, block_id * 4);
+    const index_t global_thread_id = get_thread_global_1d_id();
+    ck::philox ph(seed, global_thread_id, offset);

     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
         p_a_grid + a_batch_offset,
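Philox is a counter-based generator: its output stream is a pure function of (seed, subsequence, offset), with no hidden state shared between threads. Keying the subsequence on the global thread id therefore gives every thread an independent, reproducible stream regardless of grid shape, and the host-chosen offset positions each launch within that stream. Illustrative only (the offset values mirror the examples above; they are not fixed by the library):

    // Two launches with the same seed but different offsets draw
    // non-overlapping counter ranges within each thread's stream.
    const index_t tid = get_thread_global_1d_id();
    ck::philox ph_first(seed, tid, /*offset=*/0);
    ck::philox ph_second(seed, tid, /*offset=*/448);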
@@ -478,7 +479,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                  B1ElementwiseOperation b1_element_op,
                  CElementwiseOperation c_element_op,
                  float p_dropout,
-                 unsigned long long seed)
+                 std::tuple<unsigned long long, unsigned long long> seeds)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_b1_grid_{p_b1_grid},
@@ -527,8 +528,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
               b_grid_desc_g_n_k_,
               b1_grid_desc_g_n_k_,
               c_grid_desc_g_m_n_,
-              type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
-              seed_(seed)
+              type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
         {
             // TODO ANT: implement bias addition
             ignore = p_acc0_biases;
@@ -554,6 +554,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
             p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
             p_dropout_           = 1.f / p_dropout_;
             p_dropout_rescale_   = type_convert<GemmAccDataType>(p_dropout_);
+
+            seed_   = std::get<0>(seeds);
+            offset_ = std::get<1>(seeds);
+
+            std::cout << "seed_" << seed_ << std::endl;
+            std::cout << "offset_" << offset_ << std::endl;
         }

         void Print() const
@@ -619,6 +625,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         ushort p_dropout_in_16bits_;
         GemmAccDataType p_dropout_rescale_;
         unsigned long long seed_;
+        unsigned long long offset_;
         bool is_dropout_;
     };
@@ -692,7 +699,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                     arg.c0_matrix_mask_,
                     arg.p_dropout_in_16bits_,
                     arg.p_dropout_rescale_,
-                    arg.seed_);
+                    arg.seed_,
+                    arg.offset_);
             };

         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
@@ -846,7 +854,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                              B1ElementwiseOperation b1_element_op,
                              CElementwiseOperation c_element_op,
                              float p_dropout,
-                             const unsigned long long seed = 0)
+                             std::tuple<unsigned long long, unsigned long long> seeds)
    {
        return Argument{p_a,
                        p_b,
@@ -874,7 +882,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                        b1_element_op,
                        c_element_op,
                        p_dropout,
-                       seed};
+                       seeds};
    }

    static auto MakeInvoker() { return Invoker{}; }
@@ -910,7 +918,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op,
        float p_dropout,
-       const unsigned long long seed = 0) override
+       std::tuple<unsigned long long, unsigned long long> seeds) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
@@ -938,7 +946,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                                          b1_element_op,
                                          c_element_op,
                                          p_dropout,
-                                         seed);
+                                         seeds);
    }

    // polymorphic
...
@@ -46,15 +46,17 @@ __global__ void
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const ushort p_dropout_in_16bits,
-           GemmAccDataType p_dropout_rescale,
-           const unsigned long long seed)
+           const GemmAccDataType p_dropout_rescale,
+           const unsigned long long seed,
+           const unsigned long long offset)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
+    const index_t global_thread_id = get_thread_global_1d_id();

-    ck::philox ph(seed, 0, block_id * 4);
+    ck::philox ph(seed, global_thread_id, offset);

     const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
         cast_pointer_to_generic_address_space(group_kernel_args));
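Note that the grouped kernel keeps block_id alongside the new global_thread_id: the block id is presumably still needed below (via arg_ptr) to locate the per-group kernel arguments, while the thread id now serves only as the Philox subsequence:

    // Grouped path: block_id locates the group's kernel arguments;
    // global_thread_id keys this thread's Philox stream.
    const index_t block_id         = get_block_1d_id();
    const index_t global_thread_id = get_thread_global_1d_id();
    ck::philox    ph(seed, global_thread_id, offset);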
@@ -519,13 +521,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                 B1ElementwiseOperation b1_element_op,
                 CElementwiseOperation c_element_op,
                 float p_dropout,
-                unsigned long long seed)
+                std::tuple<unsigned long long, unsigned long long> seeds)
            : a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              acc_element_op_{acc_element_op},
              b1_element_op_{b1_element_op},
-             c_element_op_{c_element_op},
-             seed_(seed)
+             c_element_op_{c_element_op}
        {
            // TODO ANT: implement bias addition
            group_count_ = problem_desc_vec.size();
@@ -647,6 +648,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
            p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
            p_dropout_           = 1.f / p_dropout_;
            p_dropout_rescale_   = type_convert<GemmAccDataType>(p_dropout_);
+
+           seed_   = std::get<0>(seeds);
+           offset_ = std::get<1>(seeds);
        }

        std::vector<GroupKernelArg> group_kernel_args_;
@@ -664,6 +668,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        float p_dropout_;
        ushort p_dropout_in_16bits_;
        unsigned long long seed_;
+       unsigned long long offset_;
        GemmAccDataType p_dropout_rescale_;
        bool is_dropout_;
    };
@@ -726,7 +731,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                    arg.c_element_op_,
                    arg.p_dropout_in_16bits_,
                    arg.p_dropout_rescale_,
-                   arg.seed_);
+                   arg.seed_,
+                   arg.offset_);
            };

        // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
@@ -895,7 +901,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op,
                             float p_dropout,
-                            const unsigned long long seed = 0)
+                            std::tuple<unsigned long long, unsigned long long> seeds)
    {
        return Argument{p_a_vec,
                        p_b_vec,
@@ -911,7 +917,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                        b1_element_op,
                        c_element_op,
                        p_dropout,
-                       seed};
+                       seeds};
    }

    static auto MakeInvoker() { return Invoker{}; }
@@ -932,7 +938,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op,
        float p_dropout,
-       const unsigned long long seed = 0) override
+       std::tuple<unsigned long long, unsigned long long> seeds) override
    {
        return std::make_unique<Argument>(p_a_vec,
                                          p_b_vec,
@@ -948,7 +954,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                                          b1_element_op,
                                          c_element_op,
                                          p_dropout,
-                                         seed);
+                                         seeds);
    }

    // polymorphic
...
@@ -781,6 +781,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
        // gemm1 K loop
        index_t gemm1_k_block_outer_index = 0;
+
        do
        {
            auto n_block_data_idx_on_grid =
@@ -875,8 +876,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
            if constexpr(IsDropout) // dropout
            {
-               blockwise_dropout.ApplyDropout(
-                   acc_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);
+               blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
            }

            // TODO: may convert to log domain
...
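Net effect at the call site: the gridwise loop no longer threads its outer-loop position into the dropout helper, because the stream is now determined when the per-thread generator is constructed with (seed, thread id, offset). For reference:

    // Old call: the outer-loop index selected the Philox subsequence.
    // blockwise_dropout.ApplyDropout(
    //     acc_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);

    // New call: no per-iteration subsequence bookkeeping.
    blockwise_dropout.ApplyDropout(acc_thread_buf, ph);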