Commit ab1cd337 authored by danyao12
Browse files

fwd scale sync with bwd

parent a606c5e7
......@@ -1033,8 +1033,6 @@ int run(int argc, char* argv[])
std::tuple<unsigned long long, unsigned long long>(seed, offset));
qgrad_device_buf.SetZero();
// kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
// vgrad_device_buf.SetZero();
float ave_time_bwd = invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, true});
// 5 GEMM ops in total:
......@@ -1143,8 +1141,6 @@ int run(int argc, char* argv[])
fwd_file << z_fwd_gs_ms_ns << std::endl;
qgrad_device_buf.SetZero();
// kgrad_device_buf.SetZero();
// vgrad_device_buf.SetZero();
auto argument_bwd = gemm_bwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
......
......@@ -20,15 +20,10 @@ int run(int argc, char* argv[])
ck::index_t G0 = 7;
ck::index_t G1 = 13;
float alpha = 1;
bool input_permute = false;
bool output_permute = true;
float p_drop = 0.1;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......@@ -55,7 +50,7 @@ int run(int argc, char* argv[])
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
p_drop = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
......@@ -71,6 +66,11 @@ int run(int argc, char* argv[])
exit(0);
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment