Commit ab1cd337 authored by danyao12's avatar danyao12
Browse files

fwd scale sync with bwd

parent a606c5e7
...@@ -1033,8 +1033,6 @@ int run(int argc, char* argv[]) ...@@ -1033,8 +1033,6 @@ int run(int argc, char* argv[])
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
qgrad_device_buf.SetZero(); qgrad_device_buf.SetZero();
// kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
// vgrad_device_buf.SetZero();
float ave_time_bwd = invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, true}); float ave_time_bwd = invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, true});
// 5 GEMM ops in total: // 5 GEMM ops in total:
...@@ -1143,8 +1141,6 @@ int run(int argc, char* argv[]) ...@@ -1143,8 +1141,6 @@ int run(int argc, char* argv[])
fwd_file << z_fwd_gs_ms_ns << std::endl; fwd_file << z_fwd_gs_ms_ns << std::endl;
qgrad_device_buf.SetZero(); qgrad_device_buf.SetZero();
// kgrad_device_buf.SetZero();
// vgrad_device_buf.SetZero();
auto argument_bwd = gemm_bwd.MakeArgument( auto argument_bwd = gemm_bwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
......
...@@ -20,15 +20,10 @@ int run(int argc, char* argv[]) ...@@ -20,15 +20,10 @@ int run(int argc, char* argv[])
ck::index_t G0 = 7; ck::index_t G0 = 7;
ck::index_t G1 = 13; ck::index_t G1 = 13;
float alpha = 1;
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
float p_drop = 0.1; float p_drop = 0.1;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -55,7 +50,7 @@ int run(int argc, char* argv[]) ...@@ -55,7 +50,7 @@ int run(int argc, char* argv[])
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]); p_drop = std::stof(argv[10]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[12]);
...@@ -71,6 +66,11 @@ int run(int argc, char* argv[]) ...@@ -71,6 +66,11 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides = std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute input_permute
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment