Commit 89c46e79 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed test_grouped_gemm

parent bd0f0686
...@@ -52,11 +52,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -52,11 +52,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
int group_count = rand() % 10 + 1; int group_count = rand() % 10 + 1;
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b; std::vector<const void*> p_a, p_b;
std::vector<void*> p_c; std::vector<void*> p_c;
gemm_shapes.reserve(group_count); gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
...@@ -68,7 +68,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -68,7 +68,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K; int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
int CStride = std::is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value ? N : M; int CStride = std::is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value ? N : M;
gemm_shapes.push_back({M, N, K, AStride, BStride, CStride}); gemm_descs.push_back({M, N, K, AStride, BStride, CStride});
} }
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
...@@ -104,22 +104,22 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -104,22 +104,22 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
b_tensors_device.reserve(group_count); b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count); c_tensors_device.reserve(group_count);
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor( a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor( b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{}))); gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor( c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor( c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
} }
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
a_tensors_device.emplace_back( a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize())); std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
...@@ -144,7 +144,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -144,7 +144,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
auto argument_ptr = groupedGemmPtr->MakeArgumentPointer( auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); p_a, p_b, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get())); DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));
...@@ -152,7 +152,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -152,7 +152,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
invoker_ptr->Run(argument_ptr.get()); invoker_ptr->Run(argument_ptr.get());
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment