Commit b9cb659d authored by guangzlu
Browse files

Added a kernel timing test to run_grouped_multihead_attention_forward.inc

parent a2b5c7ac
...@@ -5,12 +5,11 @@ int run(int argc, char* argv[]) ...@@ -5,12 +5,11 @@ int run(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = true;
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
float p_drop = 0.2; float p_drop = 0.2;
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
...@@ -56,7 +55,8 @@ int run(int argc, char* argv[]) ...@@ -56,7 +55,8 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0; std::vector<const void*> p_b0;
std::vector<const void*> p_b1; std::vector<const void*> p_b1;
std::vector<void*> p_c; std::vector<void*> p_c;
std::vector<void*> p_z; std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse; std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o; std::vector<std::vector<int>> g0_g1_m_n_k_o;
...@@ -221,6 +221,7 @@ int run(int argc, char* argv[]) ...@@ -221,6 +221,7 @@ int run(int argc, char* argv[])
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer()); p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer()); p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
} }
...@@ -233,12 +234,13 @@ int run(int argc, char* argv[]) ...@@ -233,12 +234,13 @@ int run(int argc, char* argv[])
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = auto argument =
gemm.MakeArgument(p_a, gemm.MakeArgument(p_a,
p_b0, p_b0,
p_b1, p_b1,
p_c, p_c,
p_z, p_z_nullptr,
p_lse, p_lse,
{}, // p_acc0_biases {}, // p_acc0_biases
{}, // p_acc1_biases {}, // p_acc1_biases
...@@ -252,7 +254,6 @@ int run(int argc, char* argv[]) ...@@ -252,7 +254,6 @@ int run(int argc, char* argv[])
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
// specify workspace for problem_desc // specify workspace for problem_desc
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...@@ -277,6 +278,31 @@ int run(int argc, char* argv[]) ...@@ -277,6 +278,31 @@ int run(int argc, char* argv[])
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
argument =
gemm.MakeArgument(p_a,
p_b0,
p_b1,
p_c,
p_z,
p_lse,
{}, // p_acc0_biases
{}, // p_acc1_biases
problem_descs,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
const int& G0 = g0_g1_m_n_k_o[i][0]; const int& G0 = g0_g1_m_n_k_o[i][0];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment