Commit d976670e authored by Jakub Piasecki
Browse files

fixes

parent 54d73870
add_custom_target(example_grouped_gemm_xdl) add_custom_target(example_grouped_gemm_xdl)
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_splitk_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()
...@@ -9,8 +9,6 @@ ...@@ -9,8 +9,6 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// #include
// "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -36,7 +34,7 @@ using Row = ck::tensor_layout::gemm::RowMajor; ...@@ -36,7 +34,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddAdd = ck::tensor_operation::element_wise::AddAdd; using AddAdd = ck::tensor_operation::element_wise::AddAdd;
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
...@@ -57,10 +55,8 @@ using BElementOp = PassThrough; ...@@ -57,10 +55,8 @@ using BElementOp = PassThrough;
using CDEElementOp = AddAdd; using CDEElementOp = AddAdd;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr int NumDMatrices = 2; static constexpr int NumDMatrices = 2;
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffle
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
// clang-format off // clang-format off
...@@ -155,9 +151,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -155,9 +151,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor( auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor( auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
std::array<Tensor<DDataType>, NumDMatrices> d_tens = {d0_tensor, d1_tensor}; std::array<Tensor<DDataType>, NumDMatrices> d_tens = {d0_tensor, d1_tensor};
d_tensors.push_back(d_tens); d_tensors.push_back(d_tens);
...@@ -181,21 +177,24 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -181,21 +177,24 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
case 1: case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
for(int j = 0; j < NumDMatrices; ++j) { for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5}); d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
} }
break; break;
case 2: case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDMatrices; ++j) { for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
} }
break; break;
default: default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
for(int j = 0; j < NumDMatrices; ++j) { for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
} }
} }
...@@ -212,19 +211,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -212,19 +211,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
c_tensors_device.emplace_back(std::make_unique<DeviceMem>( c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
for(int j = 0; j < NumDMatrices; ++j) { for(int j = 0; j < NumDMatrices; ++j)
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); {
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
} }
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
for(int j = 0; j < NumDMatrices; ++j) { for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
} }
c_tensors_device[i]->SetZero(); c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_Ds.push_back({d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); p_Ds.push_back(
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({problem_size.Ms[i], gemm_descs.push_back({problem_size.Ms[i],
problem_size.Ns[i], problem_size.Ns[i],
...@@ -234,14 +237,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -234,14 +237,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
problem_size.stride_Cs[i], problem_size.stride_Cs[i],
problem_size.stride_Ds[i]}); problem_size.stride_Ds[i]});
} }
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{}; auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
//std::vector<std::array<const void*, 0>> p_Ds = {};
// do GEMM // do GEMM
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
...@@ -273,17 +275,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -273,17 +275,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
bool pass = true; bool pass = true;
if(config.do_verification) if(config.do_verification)
{ {
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType, using ReferenceGemmInstance =
BDataType, ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
DsDataType, BDataType,
EDataType, DsDataType,
AccDataType, EDataType,
AElementOp, AccDataType,
BElementOp, AElementOp,
CDEElementOp>; BElementOp,
CDEElementOp>;
//float* p_workspace_dev = reinterpret_cast<float*>(gemm_workspace_dev.GetDeviceBuffer());
//std::size_t gemm_offset{0};
for(std::size_t i = 0; i < gemm_descs.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
...@@ -291,18 +291,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -291,18 +291,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
auto dev_res_tensor = auto dev_res_tensor =
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideC, ELayout{})); Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideC, ELayout{}));
// std::cout << ">>>> Copy device data back to CPU. Group id: " << i << "\n"
// << "M: " << karg.M << ", N: " << karg.N << "\n"
// << "gemm_offset: " << gemm_offset << "\n"
// << "tensor size bytes: " << dev_res_tensor.GetElementSpaceSizeInBytes() <<
// "\n"
// << std::endl;
// hip_check_error(hipMemcpy(dev_res_tensor.data(),
// p_workspace_dev + gemm_offset,
// dev_res_tensor.GetElementSpaceSizeInBytes(),
// hipMemcpyDeviceToHost));
// hip_check_error(hipDeviceSynchronize());
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(), c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(),
c_device_result_tensors[i].mDesc.GetElementSize() * c_device_result_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType)); sizeof(EDataType));
...@@ -318,10 +306,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -318,10 +306,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
cde_element_op); cde_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
// pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
//gemm_offset += argument.GetWorkspaceSize(i);
} }
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
...@@ -367,7 +352,8 @@ int main(int argc, char* argv[]) ...@@ -367,7 +352,8 @@ int main(int argc, char* argv[])
problem_size.stride_Cs.push_back(problem_size.Ns[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]);
problem_size.stride_Ds.push_back({}); problem_size.stride_Ds.push_back({});
for(int j=0; j<NumDMatrices; ++j) { for(int j = 0; j < NumDMatrices; ++j)
{
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
} }
} }
...@@ -397,7 +383,8 @@ int main(int argc, char* argv[]) ...@@ -397,7 +383,8 @@ int main(int argc, char* argv[])
problem_size.stride_Bs = argToIntArray(argv[8]); problem_size.stride_Bs = argToIntArray(argv[8]);
problem_size.stride_Cs = argToIntArray(argv[9]); problem_size.stride_Cs = argToIntArray(argv[9]);
for(int j=0; j<NumDMatrices; ++j) { for(int j = 0; j < NumDMatrices; ++j)
{
problem_size.stride_Ds.push_back(problem_size.stride_Cs); problem_size.stride_Ds.push_back(problem_size.stride_Cs);
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -33,7 +33,7 @@ struct ReferenceGemmMultipleD : public device::BaseOperator ...@@ -33,7 +33,7 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
{ {
Argument(const Tensor<ADataType>& a_m_k, Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n, const Tensor<BDataType>& b_k_n,
const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n, const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n,
Tensor<CDataType>& c_m_n, Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -101,11 +101,16 @@ struct ReferenceGemmMultipleD : public device::BaseOperator ...@@ -101,11 +101,16 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
CDataType v_c = 0; CDataType v_c = 0;
if constexpr (DsDataType::Size() == 0) { if constexpr(DsDataType::Size() == 0)
{
arg.cde_element_op_(v_c, v_acc); arg.cde_element_op_(v_c, v_acc);
} else if constexpr(DsDataType::Size() == 1) { }
else if constexpr(DsDataType::Size() == 1)
{
arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n)); arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n));
} else if constexpr(DsDataType::Size() == 2) { }
else if constexpr(DsDataType::Size() == 2)
{
arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n), arg.ds_m_n_[1](m, n)); arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n), arg.ds_m_n_[1](m, n));
} }
...@@ -136,7 +141,7 @@ struct ReferenceGemmMultipleD : public device::BaseOperator ...@@ -136,7 +141,7 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
static auto MakeArgument(const Tensor<ADataType>& a_m_k, static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n, const Tensor<BDataType>& b_k_n,
const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n, const std::array<Tensor<DDataType>, DsDataType::Size()>& ds_m_n,
Tensor<CDataType>& c_m_n, Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -157,7 +162,7 @@ struct ReferenceGemmMultipleD : public device::BaseOperator ...@@ -157,7 +162,7 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "ReferenceGemm" str << "ReferenceGemmMultipleD"
<< std::endl; << std::endl;
// clang-format on // clang-format on
......
...@@ -55,7 +55,7 @@ check_err(const Range& out, ...@@ -55,7 +55,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 50000) if(err_count < 5)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...@@ -106,7 +106,7 @@ check_err(const Range& out, ...@@ -106,7 +106,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 50000) if(err_count < 5)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...@@ -156,7 +156,7 @@ check_err(const Range& out, ...@@ -156,7 +156,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 50000) if(err_count < 5)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...@@ -211,7 +211,7 @@ check_err(const Range& out, ...@@ -211,7 +211,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 50000) if(err_count < 5)
{ {
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl; << std::endl;
...@@ -260,7 +260,7 @@ check_err(const Range& out, ...@@ -260,7 +260,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 50000) if(err_count < 5)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...@@ -305,7 +305,7 @@ check_err(const Range& out, ...@@ -305,7 +305,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 50000) if(err_count < 5)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
...@@ -270,16 +270,15 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, ...@@ -270,16 +270,15 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
for(std::size_t i = 0; i < gemm_descs.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero(); c_device_buf[i]->SetZero();
std::cout << "p1\n";
invoker_ptr->Run(argument_ptr.get(), invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter}); StreamConfig{nullptr, false, 0, n_warmup, n_iter});
std::cout << "p2\n";
if(do_verification) if(do_verification)
{ {
bool instance_pass = true; bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
std::cout << "p3\n";
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1) if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
...@@ -317,10 +316,10 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, ...@@ -317,10 +316,10 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
pass = pass && instance_pass; pass = pass && instance_pass;
} }
std::cout << "p4\n";
float ave_time = invoker_ptr->Run( float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
std::cout << "p5\n";
if(time_kernel) if(time_kernel)
{ {
std::size_t flop = 0, num_btype = 0; std::size_t flop = 0, num_btype = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment