Commit 1d7099b6 authored by danyao12's avatar danyao12
Browse files

fix hd256 dropout scratch

parent a67bdd63
...@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16))); ...@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32))); using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64))); using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// using uint32_t = ...
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16 // i16
// using int16_t = ... // using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2))); using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
......
...@@ -53,6 +53,22 @@ class philox ...@@ -53,6 +53,22 @@ class philox
out_tmp[3] = tmp_ph.w; out_tmp[3] = tmp_ph.w;
} }
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
}
private: private:
struct ull2 struct ull2
{ {
......
...@@ -333,19 +333,26 @@ struct BlockDropout ...@@ -333,19 +333,26 @@ struct BlockDropout
uint2 rowcol = make_uint2(block_row_start, block_col_start); uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number // generate random number
uint8_t random_uint8_t[16]; uint8_t* random_uint8_t_;
uint8_t* random_uint8_t_ = random_uint8_t;
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
if constexpr(!IsWG32) if constexpr(!IsWG32)
{ {
uint8_t random_uint8_t[4];
// m0t0 ~m0t15/m0t32~m0t47: 0 // m0t0 ~m0t15/m0t32~m0t47: 0
// m0t16~m0t31/m0t48~m0t63: 1 // m0t16~m0t31/m0t48~m0t63: 1
// m1t0 ~m1t15/m1t32~m1t47: 2 // m1t0 ~m1t15/m1t32~m1t47: 2
// m1t16~m1t31/m1t48~m1t63: 3 // m1t16~m1t31/m1t48~m1t63: 3
int start_idx = ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1); const index_t start_idx =
uint32_t* random_uint32_t = reinterpret_cast<uint32_t*>(random_uint8_t); ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
random_uint8_t_ = reinterpret_cast<uint8_t*>(&random_uint32_t[start_idx]); ph.get_random_4x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else
{
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&>(rowcol));
random_uint8_t_ = random_uint8_t;
} }
constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment