Commit 1d7099b6 authored by danyao12's avatar danyao12
Browse files

fix hd256 dropout scratch

parent a67bdd63
......@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// using uint32_t = ...
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16
// using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
......
......@@ -53,6 +53,22 @@ class philox
out_tmp[3] = tmp_ph.w;
}
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
}
private:
struct ull2
{
......
......@@ -333,19 +333,26 @@ struct BlockDropout
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
uint8_t* random_uint8_t_ = random_uint8_t;
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
uint8_t* random_uint8_t_;
if constexpr(!IsWG32)
{
uint8_t random_uint8_t[4];
// m0t0 ~m0t15/m0t32~m0t47: 0
// m0t16~m0t31/m0t48~m0t63: 1
// m1t0 ~m1t15/m1t32~m1t47: 2
// m1t16~m1t31/m1t48~m1t63: 3
int start_idx = ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
uint32_t* random_uint32_t = reinterpret_cast<uint32_t*>(random_uint8_t);
random_uint8_t_ = reinterpret_cast<uint8_t*>(&random_uint32_t[start_idx]);
const index_t start_idx =
((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
ph.get_random_4x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else
{
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&>(rowcol));
random_uint8_t_ = random_uint8_t;
}
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment