"vscode:/vscode.git/clone" did not exist on "381dd9feb4a1c5b900a779d1eb05938f53fb8865"
Unverified Commit fe6ce55c authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Grouped gemm test fix (#150)

* fixed test: return res; rand gemm shapes

* fixed return
parent 313bbea5
......@@ -66,7 +66,7 @@ static bool check_err(const Tensor<T>& ref, const Tensor<T>& result)
bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
{
int group_count = 4;
int group_count = rand() % 10 + 1;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
......@@ -77,9 +77,9 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
for(int i = 0; i < group_count; i++)
{
int M = 256 + 256 * i;
int N = 128 + 128 * i;
int K = 128 + 64 * i;
int M = 256 + 256 * (rand() % 10);
int N = 256 + 256 * (rand() % 10);
int K = 128 + 128 * (rand() % 10);
int AStride = std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? K : M;
int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
......@@ -132,8 +132,8 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
}
for(int i = 0; i < gemm_shapes.size(); i++)
......@@ -181,6 +181,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
b_element_op,
c_element_op);
if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get()))
{
return false;
}
ref_invoker.Run(ref_argument);
bool res = check_err(c_device_tensors[i], c_host_tensors[i]);
......@@ -210,4 +215,6 @@ int main()
}
std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment