Commit 9574b34d authored by danyao12's avatar danyao12
Browse files

adjust grouped kernels interface

parent 29398e70
...@@ -645,7 +645,6 @@ int run(int argc, char* argv[]) ...@@ -645,7 +645,6 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...@@ -694,7 +693,6 @@ int run(int argc, char* argv[]) ...@@ -694,7 +693,6 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer()); gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
......
...@@ -657,7 +657,6 @@ int run(int argc, char* argv[]) ...@@ -657,7 +657,6 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...@@ -707,7 +706,6 @@ int run(int argc, char* argv[]) ...@@ -707,7 +706,6 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer()); gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
......
...@@ -721,8 +721,7 @@ int run(int argc, char* argv[]) ...@@ -721,8 +721,7 @@ int run(int argc, char* argv[])
Scale{alpha}, Scale{alpha},
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should {seed, offset}); // dropout random seed and offset, offset should
// be at least the number of elements on a thread // be at least the number of elements on a thread
...@@ -770,7 +769,6 @@ int run(int argc, char* argv[]) ...@@ -770,7 +769,6 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_bwd(gemm_bwd.GetWorkSpaceSize(&argument_bwd)); DeviceMem problem_desc_workspace_bwd(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
...@@ -820,8 +818,7 @@ int run(int argc, char* argv[]) ...@@ -820,8 +818,7 @@ int run(int argc, char* argv[])
Scale{alpha}, Scale{alpha},
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should {seed, offset}); // dropout random seed and offset, offset should
// be at least the number of elements on a thread // be at least the number of elements on a thread
...@@ -861,7 +858,6 @@ int run(int argc, char* argv[]) ...@@ -861,7 +858,6 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_bwd_verify(gemm_bwd.GetWorkSpaceSize(&argument_bwd)); DeviceMem problem_desc_workspace_bwd_verify(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
gemm_bwd.SetWorkSpacePointer(&argument_bwd, gemm_bwd.SetWorkSpacePointer(&argument_bwd,
......
...@@ -258,8 +258,7 @@ int run(int argc, char* argv[]) ...@@ -258,8 +258,7 @@ int run(int argc, char* argv[])
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
...@@ -302,8 +301,7 @@ int run(int argc, char* argv[]) ...@@ -302,8 +301,7 @@ int run(int argc, char* argv[])
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
// specify workspace for problem_desc // specify workspace for problem_desc
......
...@@ -684,7 +684,6 @@ int run(int argc, char* argv[]) ...@@ -684,7 +684,6 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...@@ -733,7 +732,6 @@ int run(int argc, char* argv[]) ...@@ -733,7 +732,6 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer()); gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
......
...@@ -280,8 +280,7 @@ int run(int argc, char* argv[]) ...@@ -280,8 +280,7 @@ int run(int argc, char* argv[])
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
...@@ -336,8 +335,7 @@ int run(int argc, char* argv[]) ...@@ -336,8 +335,7 @@ int run(int argc, char* argv[])
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
// specify workspace for problem_desc // specify workspace for problem_desc
......
...@@ -134,7 +134,6 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator ...@@ -134,7 +134,6 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_dropout, float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) = 0; std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
......
...@@ -1321,8 +1321,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1321,8 +1321,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 && if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
b_g <= c_g))
{ {
return false; return false;
} }
......
...@@ -1353,8 +1353,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1353,8 +1353,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 && if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
b_g <= c_g))
{ {
return false; return false;
} }
......
...@@ -1180,8 +1180,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1180,8 +1180,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 && if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
b_g <= c_g))
{ {
return false; return false;
} }
......
...@@ -1213,8 +1213,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1213,8 +1213,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 && if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
b_g <= c_g))
{ {
return false; return false;
} }
......
...@@ -1024,8 +1024,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1024,8 +1024,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 && if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
b_g <= c_g))
{ {
return false; return false;
} }
......
...@@ -930,15 +930,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -930,15 +930,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op}, : a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
p_dropout_{p_drop}, p_dropout_{p_drop}
h_ratio_{h_ratio}
{ {
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -972,6 +970,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -972,6 +970,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d_grid_size_ = 0; d_grid_size_ = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b_gs_ns_ks_lengths[NumDimG - 1];
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
...@@ -1453,7 +1454,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1453,7 +1454,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_As, return Argument{p_As,
...@@ -1478,7 +1478,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1478,7 +1478,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds}; seeds};
} }
...@@ -1509,7 +1508,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1509,7 +1508,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(p_As, return std::make_unique<Argument>(p_As,
...@@ -1534,7 +1532,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1534,7 +1532,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds); seeds);
} }
......
...@@ -999,15 +999,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -999,15 +999,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op}, : a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
p_dropout_{p_drop}, p_dropout_{p_drop}
h_ratio_{h_ratio}
{ {
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -1041,6 +1039,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1041,6 +1039,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d_grid_size_ = 0; d_grid_size_ = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b_gs_ns_ks_lengths[NumDimG - 1];
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
...@@ -1527,7 +1528,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1527,7 +1528,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_As, return Argument{p_As,
...@@ -1552,7 +1552,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1552,7 +1552,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds}; seeds};
} }
...@@ -1583,7 +1582,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1583,7 +1582,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(p_As, return std::make_unique<Argument>(p_As,
...@@ -1608,7 +1606,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1608,7 +1606,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds); seeds);
} }
......
...@@ -812,15 +812,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -812,15 +812,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op}, : a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
p_dropout_{p_drop}, p_dropout_{p_drop}
h_ratio_{h_ratio}
{ {
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -851,6 +849,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -851,6 +849,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
index_t z_random_matrix_offset = 0; index_t z_random_matrix_offset = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b_gs_ns_ks_lengths[NumDimG - 1];
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
...@@ -1297,7 +1298,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1297,7 +1298,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_As, return Argument{p_As,
...@@ -1321,7 +1321,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1321,7 +1321,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds}; seeds};
} }
...@@ -1351,7 +1350,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1351,7 +1350,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(p_As, return std::make_unique<Argument>(p_As,
...@@ -1375,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1375,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds); seeds);
} }
......
...@@ -882,15 +882,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -882,15 +882,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op}, : a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
p_dropout_{p_drop}, p_dropout_{p_drop}
h_ratio_{h_ratio}
{ {
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -921,6 +919,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -921,6 +919,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t z_random_matrix_offset = 0; index_t z_random_matrix_offset = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b_gs_ns_ks_lengths[NumDimG - 1];
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
...@@ -1372,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1372,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_As, return Argument{p_As,
...@@ -1396,7 +1396,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1396,7 +1396,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds}; seeds};
} }
...@@ -1426,7 +1425,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1426,7 +1425,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(p_As, return std::make_unique<Argument>(p_As,
...@@ -1450,7 +1448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1450,7 +1448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, p_drop,
h_ratio,
seeds); seeds);
} }
......
...@@ -682,14 +682,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -682,14 +682,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_dropout, float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op}, : a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op}
h_ratio_{h_ratio}
{ {
ignore = p_acc1_biases_vec; ignore = p_acc1_biases_vec;
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
...@@ -708,6 +706,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -708,6 +706,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t z_random_matrix_offset = 0; index_t z_random_matrix_offset = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b0_gs_ns_ks_lengths[NumDimG - 1];
for(std::size_t i = 0; i < group_count_; i++) for(std::size_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]); const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
...@@ -1214,7 +1215,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1214,7 +1215,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_dropout, float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_a_vec, return Argument{p_a_vec,
...@@ -1232,7 +1232,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1232,7 +1232,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_dropout, p_dropout,
h_ratio,
seeds}; seeds};
} }
...@@ -1255,7 +1254,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1255,7 +1254,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_dropout, float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) override std::tuple<unsigned long long, unsigned long long> seeds) override
{ {
return std::make_unique<Argument>(p_a_vec, return std::make_unique<Argument>(p_a_vec,
...@@ -1273,7 +1271,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1273,7 +1271,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_dropout, p_dropout,
h_ratio,
seeds); seeds);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment