Commit 26115ce7 authored by guangzlu's avatar guangzlu
Browse files

fixed blockwise_dropout.hpp

parent ce31e22a
...@@ -64,12 +64,12 @@ struct BlockwiseDropout ...@@ -64,12 +64,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8; int philox_calls = tmp_size / 4;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * 8); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
} }
block_sync_lds(); block_sync_lds();
...@@ -140,20 +140,11 @@ struct BlockwiseDropout ...@@ -140,20 +140,11 @@ struct BlockwiseDropout
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 4;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
// ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
} }
// int philox_calls_2 = tmp_size / 4;
// ushort tmp_id[tmp_size];
// for(int j = 0; j < philox_calls_2; j++){
// for(int i = 0; i < 4; i++){
// tmp_id[j * 4 + i] = element_global_1d_id + j * 8;
// }
//}
block_sync_lds(); block_sync_lds();
int tmp_index = 0; int tmp_index = 0;
...@@ -168,59 +159,6 @@ struct BlockwiseDropout ...@@ -168,59 +159,6 @@ struct BlockwiseDropout
}); });
} }
template <typename CThreadBuffer,
typename ZThreadBuffer,
bool using_sign_bit,
typename N0,
typename Offset>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf,
ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
int philox_calls = tmp_size / 8;
int philox_calls_2 = tmp_size / 4;
ushort tmp[tmp_size];
ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * 8);
}
for(int j = 0; j < philox_calls_2; j++)
{
for(int i = 0; i < 4; i++)
{
tmp_id[j * 4 + i] = element_global_1d_id + j * 8;
}
}
// if(get_thread_global_1d_id() == 0){
// printf("tmp_size is %d \n", tmp_size);
// //printf("n0.value is %d \n", n0.value);
//}
block_sync_lds();
constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + iOffset) =
execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset));
z_thread_buf(i) = tmp_id[i.value];
});
}
template <typename CThreadBuffer, template <typename CThreadBuffer,
typename ZThreadBuffer, typename ZThreadBuffer,
bool using_sign_bit, bool using_sign_bit,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment