"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "e3a24b79a4b45b4032a27c0cb09b6715af3422c5"
Unverified Commit fe6ce55c authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Grouped gemm test fix (#150)

* fixed test: return res; rand gemm shapes

* fixed return
parent 313bbea5
...@@ -66,7 +66,7 @@ static bool check_err(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -66,7 +66,7 @@ static bool check_err(const Tensor<T>& ref, const Tensor<T>& result)
bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
{ {
int group_count = 4; int group_count = rand() % 10 + 1;
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
...@@ -77,9 +77,9 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -77,9 +77,9 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
int M = 256 + 256 * i; int M = 256 + 256 * (rand() % 10);
int N = 128 + 128 * i; int N = 256 + 256 * (rand() % 10);
int K = 128 + 64 * i; int K = 128 + 128 * (rand() % 10);
int AStride = std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? K : M; int AStride = std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? K : M;
int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K; int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
...@@ -132,8 +132,8 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -132,8 +132,8 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor( c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
} }
for(int i = 0; i < gemm_shapes.size(); i++) for(int i = 0; i < gemm_shapes.size(); i++)
...@@ -181,6 +181,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -181,6 +181,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
b_element_op, b_element_op,
c_element_op); c_element_op);
if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get()))
{
return false;
}
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
bool res = check_err(c_device_tensors[i], c_host_tensors[i]); bool res = check_err(c_device_tensors[i], c_host_tensors[i]);
...@@ -210,4 +215,6 @@ int main() ...@@ -210,4 +215,6 @@ int main()
} }
std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment