Commit 3a9ab7a7 authored by guangzlu's avatar guangzlu
Browse files

changed random number generate into hiprand

parent 7c4c31cf
...@@ -7,6 +7,8 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") ...@@ -7,6 +7,8 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
enable_testing() enable_testing()
add_definitions(-w)
set(ROCM_SYMLINK_LIBS OFF) set(ROCM_SYMLINK_LIBS OFF)
find_package(ROCM REQUIRED PATHS /opt/rocm) find_package(ROCM REQUIRED PATHS /opt/rocm)
......
...@@ -49,7 +49,8 @@ using B1DataType = DataType; ...@@ -49,7 +49,8 @@ using B1DataType = DataType;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; //using ZDataType = U16;
using ZDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = ck::Tuple<>;
......
...@@ -27,7 +27,7 @@ int run(int argc, char* argv[]) ...@@ -27,7 +27,7 @@ int run(int argc, char* argv[])
float p_drop = 0.1; float p_drop = 0.1;
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); uint16_t p_dropout_in_float = p_dropout;//uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -322,7 +322,7 @@ int run(int argc, char* argv[]) ...@@ -322,7 +322,7 @@ int run(int argc, char* argv[])
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument( auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout); z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_float, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment); ref_dropout_invoker.Run(ref_dropout_argment);
// gemm1 // gemm1
......
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
#pragma once #pragma once
#include "hiprand.h"
#include "hiprand_kernel.h"
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/philox_rand.hpp" #include "ck/utility/philox_rand.hpp"
...@@ -11,6 +14,23 @@ namespace ck { ...@@ -11,6 +14,23 @@ namespace ck {
template <typename DataType, typename ThreadSliceDesc_M_K> template <typename DataType, typename ThreadSliceDesc_M_K>
struct BlockwiseDropout struct BlockwiseDropout
{ {
//__host__ __device__ BlockwiseDropout(){}
//
//__host__ __device__ BlockwiseDropout(ushort p_dropout_in_16bits, DataType p_dropout_to_rescale)
//{
// p_dropout_16bits = p_dropout_in_16bits;
// p_dropout_rescale = p_dropout_to_rescale;
//}
//
//__host__ __device__ BlockwiseDropout(float p_dropout_in_float, DataType p_dropout_to_rescale)
//{
// p_dropout_float = p_dropout_in_float;
// p_dropout_rescale = p_dropout_to_rescale;
//}
//~BlockwiseDropout(){}
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0); static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
...@@ -50,6 +70,48 @@ struct BlockwiseDropout ...@@ -50,6 +70,48 @@ struct BlockwiseDropout
}); });
} }
template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, hiprandState_t& state)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int hiprand_calls = tmp_size / 8;
float tmp[tmp_size];
for(int i = 0; i < hiprand_calls; i++)
{
float tmp_rand = hiprand_uniform(&state);
tmp[i] = tmp_rand;
tmp[i+1] = tmp_rand;
tmp[i+2] = tmp_rand;
tmp[i+3] = tmp_rand;
tmp[i+4] = tmp_rand;
tmp[i+5] = tmp_rand;
tmp[i+6] = tmp_rand;
tmp[i+7] = tmp_rand;
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_float, in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false> template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void __host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf) ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
...@@ -122,7 +184,44 @@ struct BlockwiseDropout ...@@ -122,7 +184,44 @@ struct BlockwiseDropout
}); });
} }
template <typename CThreadBuffer,
typename ZThreadBuffer,
bool using_sign_bit,
typename N0,
typename Offset>
__host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, hiprandState_t& state, ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
int philox_calls = tmp_size;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
tmp[i] = hiprand_uniform(&state);
}
block_sync_lds();
constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + iOffset) =
execute_dropout(tmp[i.value] <= p_dropout_float, in_thread_buf(i + iOffset));
z_thread_buf(i) = tmp[i.value];
});
}
ushort p_dropout_16bits; ushort p_dropout_16bits;
float p_dropout_float;
DataType p_dropout_rescale; DataType p_dropout_rescale;
}; };
......
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "hiprand.h"
#include "hiprand_kernel.h"
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp" //#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
...@@ -74,7 +77,7 @@ __global__ void ...@@ -74,7 +77,7 @@ __global__ void
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits, const float p_dropout_in_float,
const GemmAccDataType p_dropout_rescale, const GemmAccDataType p_dropout_rescale,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
...@@ -99,7 +102,9 @@ __global__ void ...@@ -99,7 +102,9 @@ __global__ void
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id(); const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset); hiprandState_t state;
hiprand_init(seed, global_thread_id, offset, &state);
//ck::philox ph(seed, global_thread_id, offset);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
...@@ -122,9 +127,9 @@ __global__ void ...@@ -122,9 +127,9 @@ __global__ void
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
p_dropout_in_16bits, p_dropout_in_float,
p_dropout_rescale, p_dropout_rescale,
ph); state);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -591,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -591,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
is_dropout_ = p_dropout > 0.0; // is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout; p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0)); p_dropout_in_float_ = p_dropout_ ; //uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_; p_dropout_ = 1.f / p_dropout_;
p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_); p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
...@@ -673,7 +678,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -673,7 +678,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_; float p_dropout_;
ushort p_dropout_in_16bits_; ushort p_dropout_in_float_;
GemmAccDataType p_dropout_rescale_; GemmAccDataType p_dropout_rescale_;
unsigned long long seed_; unsigned long long seed_;
unsigned long long offset_; unsigned long long offset_;
...@@ -757,7 +762,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -757,7 +762,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_, arg.p_dropout_in_float_,
arg.p_dropout_rescale_, arg.p_dropout_rescale_,
arg.seed_, arg.seed_,
arg.offset_); arg.offset_);
......
...@@ -3,8 +3,11 @@ ...@@ -3,8 +3,11 @@
#pragma once #pragma once
#include "hiprand.h"
#include "hiprand_kernel.h"
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp" //#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -443,9 +446,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -443,9 +446,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits, const float p_dropout_in_float,
FloatGemmAcc p_dropout_rescale, FloatGemmAcc p_dropout_rescale,
ck::philox ph) hiprandState_t state)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -792,7 +795,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -792,7 +795,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
decltype(thread_slice_desc_m_n)>{}; decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{ auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, p_dropout_rescale}; 0, p_dropout_in_float, p_dropout_rescale};
const index_t num_gemm1_k_block_outer_loop = const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
...@@ -1013,7 +1016,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1013,7 +1016,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false, false,
decltype(n0), decltype(n0),
decltype(i)>( decltype(i)>(
acc_thread_buf, ph, z_tenor_buffer); acc_thread_buf, state, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
...@@ -1037,7 +1040,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1037,7 +1040,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// ignore = z_grid_buf; // ignore = z_grid_buf;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>( blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>(
acc_thread_buf, ph); acc_thread_buf, state);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment