"vscode:/vscode.git/clone" did not exist on "175ae3ab91d95791635e1d8be7880bc03e7af612"
Commit d976670e authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

fixes

parent 54d73870
add_custom_target(example_grouped_gemm_xdl)
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_splitk_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()
......@@ -9,8 +9,6 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// #include
// "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -59,8 +57,6 @@ using CDEElementOp = AddAdd;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr int NumDMatrices = 2;
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffle
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
// clang-format off
......@@ -181,21 +177,24 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
for(int j = 0; j < NumDMatrices; ++j) {
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
}
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDMatrices; ++j) {
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
}
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
for(int j = 0; j < NumDMatrices; ++j) {
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
}
}
......@@ -212,19 +211,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
for(int j = 0; j < NumDMatrices; ++j) {
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
}
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
for(int j = 0; j < NumDMatrices; ++j) {
for(int j = 0; j < NumDMatrices; ++j)
{
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
}
c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_Ds.push_back({d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
p_Ds.push_back(
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({problem_size.Ms[i],
problem_size.Ns[i],
......@@ -241,7 +244,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
//std::vector<std::array<const void*, 0>> p_Ds = {};
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
......@@ -273,7 +275,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
BDataType,
DsDataType,
EDataType,
......@@ -282,27 +285,12 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
BElementOp,
CDEElementOp>;
//float* p_workspace_dev = reinterpret_cast<float*>(gemm_workspace_dev.GetDeviceBuffer());
//std::size_t gemm_offset{0};
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
auto karg = argument.gemm_kernel_args_[i].karg_;
auto dev_res_tensor =
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideC, ELayout{}));
// std::cout << ">>>> Copy device data back to CPU. Group id: " << i << "\n"
// << "M: " << karg.M << ", N: " << karg.N << "\n"
// << "gemm_offset: " << gemm_offset << "\n"
// << "tensor size bytes: " << dev_res_tensor.GetElementSpaceSizeInBytes() <<
// "\n"
// << std::endl;
// hip_check_error(hipMemcpy(dev_res_tensor.data(),
// p_workspace_dev + gemm_offset,
// dev_res_tensor.GetElementSpaceSizeInBytes(),
// hipMemcpyDeviceToHost));
// hip_check_error(hipDeviceSynchronize());
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(),
c_device_result_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
......@@ -318,10 +306,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
cde_element_op);
ref_invoker.Run(ref_argument);
// pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
//gemm_offset += argument.GetWorkspaceSize(i);
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
......@@ -367,7 +352,8 @@ int main(int argc, char* argv[])
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
problem_size.stride_Ds.push_back({});
for(int j=0; j<NumDMatrices; ++j) {
for(int j = 0; j < NumDMatrices; ++j)
{
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
}
}
......@@ -397,7 +383,8 @@ int main(int argc, char* argv[])
problem_size.stride_Bs = argToIntArray(argv[8]);
problem_size.stride_Cs = argToIntArray(argv[9]);
for(int j=0; j<NumDMatrices; ++j) {
for(int j = 0; j < NumDMatrices; ++j)
{
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -101,11 +101,16 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
CDataType v_c = 0;
if constexpr (DsDataType::Size() == 0) {
if constexpr(DsDataType::Size() == 0)
{
arg.cde_element_op_(v_c, v_acc);
} else if constexpr(DsDataType::Size() == 1) {
}
else if constexpr(DsDataType::Size() == 1)
{
arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n));
} else if constexpr(DsDataType::Size() == 2) {
}
else if constexpr(DsDataType::Size() == 2)
{
arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n), arg.ds_m_n_[1](m, n));
}
......@@ -157,7 +162,7 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemm"
str << "ReferenceGemmMultipleD"
<< std::endl;
// clang-format on
......
......@@ -55,7 +55,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 50000)
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......@@ -106,7 +106,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 50000)
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......@@ -156,7 +156,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 50000)
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......@@ -211,7 +211,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 50000)
if(err_count < 5)
{
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl;
......@@ -260,7 +260,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 50000)
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......@@ -305,7 +305,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 50000)
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
......@@ -270,16 +270,15 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
std::cout << "p1\n";
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
std::cout << "p2\n";
if(do_verification)
{
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
std::cout << "p3\n";
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
......@@ -317,10 +316,10 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
pass = pass && instance_pass;
}
std::cout << "p4\n";
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
std::cout << "p5\n";
if(time_kernel)
{
std::size_t flop = 0, num_btype = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment