Commit 9574b34d authored by danyao12
Browse files

adjust grouped kernels interface

parent 29398e70
......@@ -645,7 +645,6 @@ int run(int argc, char* argv[])
QKVElementOp{},
YElementOp{},
p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
......@@ -694,7 +693,6 @@ int run(int argc, char* argv[])
QKVElementOp{},
YElementOp{},
p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
......
......@@ -657,7 +657,6 @@ int run(int argc, char* argv[])
QKVElementOp{},
YElementOp{},
p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
......@@ -707,7 +706,6 @@ int run(int argc, char* argv[])
QKVElementOp{},
YElementOp{},
p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
......
......@@ -721,8 +721,7 @@ int run(int argc, char* argv[])
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop, // dropout ratio
h_ratio,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should
// be at least the number of elements on a thread
......@@ -770,7 +769,6 @@ int run(int argc, char* argv[])
QKVElementOp{},
YElementOp{},
p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_bwd(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
......@@ -820,8 +818,7 @@ int run(int argc, char* argv[])
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop, // dropout ratio
h_ratio,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should
// be at least the number of elements on a thread
......@@ -861,7 +858,6 @@ int run(int argc, char* argv[])
QKVElementOp{},
YElementOp{},
p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_bwd_verify(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
gemm_bwd.SetWorkSpacePointer(&argument_bwd,
......
......@@ -258,8 +258,7 @@ int run(int argc, char* argv[])
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
h_ratio,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
......@@ -302,8 +301,7 @@ int run(int argc, char* argv[])
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
h_ratio,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
......
......@@ -684,7 +684,6 @@ int run(int argc, char* argv[])
QKVElementOp{},
YElementOp{},
p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
......@@ -733,7 +732,6 @@ int run(int argc, char* argv[])
QKVElementOp{},
YElementOp{},
p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
......
......@@ -280,8 +280,7 @@ int run(int argc, char* argv[])
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
h_ratio,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
......@@ -336,8 +335,7 @@ int run(int argc, char* argv[])
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
h_ratio,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
......
......@@ -134,7 +134,6 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
......
......@@ -1321,8 +1321,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
b_g <= c_g))
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
{
return false;
}
......
......@@ -1353,8 +1353,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
b_g <= c_g))
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
{
return false;
}
......
......@@ -1180,8 +1180,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
b_g <= c_g))
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
{
return false;
}
......
......@@ -1213,8 +1213,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
b_g <= c_g))
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
{
return false;
}
......
......@@ -1024,8 +1024,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
b_g <= c_g))
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
{
return false;
}
......
......@@ -930,15 +930,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
p_dropout_{p_drop},
h_ratio_{h_ratio}
p_dropout_{p_drop}
{
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -972,6 +970,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d_grid_size_ = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b_gs_ns_ks_lengths[NumDimG - 1];
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
......@@ -1453,7 +1454,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_As,
......@@ -1478,7 +1478,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
b1_element_op,
c_element_op,
p_drop,
h_ratio,
seeds};
}
......@@ -1509,7 +1508,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(p_As,
......@@ -1534,7 +1532,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
b1_element_op,
c_element_op,
p_drop,
h_ratio,
seeds);
}
......
......@@ -999,15 +999,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
p_dropout_{p_drop},
h_ratio_{h_ratio}
p_dropout_{p_drop}
{
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -1041,6 +1039,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d_grid_size_ = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b_gs_ns_ks_lengths[NumDimG - 1];
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
......@@ -1527,7 +1528,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_As,
......@@ -1552,7 +1552,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
b1_element_op,
c_element_op,
p_drop,
h_ratio,
seeds};
}
......@@ -1583,7 +1582,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(p_As,
......@@ -1608,7 +1606,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
b1_element_op,
c_element_op,
p_drop,
h_ratio,
seeds);
}
......
......@@ -812,15 +812,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
p_dropout_{p_drop},
h_ratio_{h_ratio}
p_dropout_{p_drop}
{
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -851,6 +849,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
index_t z_random_matrix_offset = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b_gs_ns_ks_lengths[NumDimG - 1];
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
......@@ -1297,7 +1298,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_As,
......@@ -1321,7 +1321,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
b1_element_op,
c_element_op,
p_drop,
h_ratio,
seeds};
}
......@@ -1351,7 +1350,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(p_As,
......@@ -1375,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
b1_element_op,
c_element_op,
p_drop,
h_ratio,
seeds);
}
......
......@@ -882,15 +882,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
p_dropout_{p_drop},
h_ratio_{h_ratio}
p_dropout_{p_drop}
{
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
......@@ -921,6 +919,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t z_random_matrix_offset = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b_gs_ns_ks_lengths[NumDimG - 1];
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
......@@ -1372,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_As,
......@@ -1396,7 +1396,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
b1_element_op,
c_element_op,
p_drop,
h_ratio,
seeds};
}
......@@ -1426,7 +1425,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(p_As,
......@@ -1450,7 +1448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
b1_element_op,
c_element_op,
p_drop,
h_ratio,
seeds);
}
......
......@@ -682,14 +682,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
h_ratio_{h_ratio}
c_element_op_{c_element_op}
{
ignore = p_acc1_biases_vec;
// TODO ANT: implement bias addition
......@@ -708,6 +706,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t z_random_matrix_offset = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b0_gs_ns_ks_lengths[NumDimG - 1];
for(std::size_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
......@@ -1214,7 +1215,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_a_vec,
......@@ -1232,7 +1232,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op,
c_element_op,
p_dropout,
h_ratio,
seeds};
}
......@@ -1255,7 +1254,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
index_t h_ratio,
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(p_a_vec,
......@@ -1273,7 +1271,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op,
c_element_op,
p_dropout,
h_ratio,
seeds);
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment