Commit 3a9dabcf, authored by guangzlu
Browse files

updated philox and pass pt3

parent d37c1d0b
...@@ -137,14 +137,23 @@ struct BlockwiseDropout ...@@ -137,14 +137,23 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8; int philox_calls = tmp_size / 4;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
// ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * 8); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
} }
// int philox_calls_2 = tmp_size / 4;
// ushort tmp_id[tmp_size];
// for(int j = 0; j < philox_calls_2; j++){
// for(int i = 0; i < 4; i++){
// tmp_id[j * 4 + i] = element_global_1d_id + j * 8;
// }
//}
block_sync_lds(); block_sync_lds();
int tmp_index = 0; int tmp_index = 0;
...@@ -180,20 +189,35 @@ struct BlockwiseDropout ...@@ -180,20 +189,35 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat / N0{}.value; constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
int philox_calls = tmp_size / 8; int philox_calls = tmp_size / 8;
int philox_calls_2 = tmp_size / 4;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * 8); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * 8);
} }
for(int j = 0; j < philox_calls_2; j++)
{
for(int i = 0; i < 4; i++)
{
tmp_id[j * 4 + i] = element_global_1d_id + j * 8;
}
}
// if(get_thread_global_1d_id() == 0){
// printf("tmp_size is %d \n", tmp_size);
// //printf("n0.value is %d \n", n0.value);
//}
block_sync_lds(); block_sync_lds();
constexpr auto iOffset = Number<tmp_size>{} * Offset{}; constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for<0, tmp_size, 1>{}([&](auto i) { static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + iOffset) = in_thread_buf(i + iOffset) =
execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset)); execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset));
z_thread_buf(i) = tmp[i.value]; z_thread_buf(i) = tmp_id[i.value];
}); });
} }
......
...@@ -84,6 +84,17 @@ class philox ...@@ -84,6 +84,17 @@ class philox
out_tmp[3] = tmp_ph.w; out_tmp[3] = tmp_ph.w;
} }
// Fills out[0..3] with four 16-bit values taken from a single
// get_philox_4x32(subsequence) draw.
//
// out         - destination; must have room for at least four ushort elements.
// subsequence - forwarded unchanged to get_philox_4x32.
//
// Each of the x/y/z/w components is narrowed to ushort with static_cast, so
// only the part of each component that fits in a ushort is kept.
__device__ void get_random_4x16(ushort* out, const unsigned long long subsequence)
{
    const uint4 draw = get_philox_4x32(subsequence);

    ushort* dst = out;
    *dst++      = static_cast<ushort>(draw.x);
    *dst++      = static_cast<ushort>(draw.y);
    *dst++      = static_cast<ushort>(draw.z);
    *dst        = static_cast<ushort>(draw.w);
}
private: private:
struct ull2 struct ull2
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment