Commit c56f28b0 authored by danyao12's avatar danyao12
Browse files

let grouped fwd test more random

parent 62357796
......@@ -231,7 +231,8 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstance =
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
// NumDimG,
// NumDimM,
......
......@@ -380,7 +380,8 @@ using DeviceGemmInstanceBWD =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstanceBWD =
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
// NumDimG,
// NumDimM,
......
......@@ -10,10 +10,7 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = true;
float p_drop = 0.1;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float p_drop = 0.2;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......@@ -27,14 +24,15 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
input_permute = std::stoi(argv[4]);
output_permute = std::stoi(argv[5]);
p_drop = std::stoi(argv[4]);
input_permute = std::stoi(argv[5]);
output_permute = std::stoi(argv[6]);
}
else
{
......@@ -45,6 +43,10 @@ int run(int argc, char* argv[])
exit(0);
}
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1; // scaling after 1st gemm
std::size_t group_count = 8;
......@@ -81,14 +83,14 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
int M = 128 * (rand() % 8 + 1);
int N = 128 * (rand() % 8 + 1);
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
#if(RANGE_HDKO == 0)
int K = 32; // K/O<=32
#elif(RANGE_HDKO == 1)
int K = 64; // 32<K/O<=64
int K = 56; // 32<K/O<=64
#elif(RANGE_HDKO == 2)
int K = 72; // 64<K/O<=128
int K = 80; // 64<K/O<=128
#endif
int O = K;
int G0 = rand() % 3 + 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment