Commit 635b5904 authored by letaoqin

start

parent ceaed8e0
CMakeLists.txt

set(GEMM_BIAS_ADD_SOURCES
    gemm_bias_add_xdl_fp16.cpp
    gemm_bias_add_fp16.cpp
)
add_executable(example_gemm_bias_add_xdl_fp16 ${GEMM_BIAS_ADD_SOURCES})
target_link_libraries(example_gemm_bias_add_xdl_fp16 PRIVATE utility)
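Note: `utility` here is presumably the shared helper library target that CK's other examples link against, and this snippet would sit in the examples' CMake tree next to the two sources it lists.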
gemm_bias_add.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"

struct GemmBiasAddArgs
{
    const void* mat_a;    // device pointer to A: M x K fp16, row-major
    const void* mat_b;    // device pointer to B: K x N fp16, column-major
    const void* mat_bias; // device pointer to bias: N fp16 values, broadcast along M
    void* mat_c;          // device pointer to output C = A * B + bias: M x N fp16, row-major
    ck::index_t M;
    ck::index_t N;
    ck::index_t K;
};

// Runs the fused GEMM + bias-add kernel and returns the average kernel time in ms
// (0 when timing is disabled in `config`).
float gemm_bias_add_fp16(const GemmBiasAddArgs& args, const StreamConfig& config);
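For reference, a minimal call-site sketch. The names d_a, d_b, d_bias and d_c are hypothetical stand-ins for device allocations with the layouts documented above; the full host-side setup lives in the example program further down.

// d_a: M x K fp16 row-major, d_b: K x N fp16 column-major,
// d_bias: N fp16 values, d_c: M x N fp16 output (all device pointers).
GemmBiasAddArgs args{d_a, d_b, d_bias, d_c, /*M=*/512, /*N=*/1024, /*K=*/256};

// Default stream, timing disabled, so the returned time is 0.
float t_ms = gemm_bias_add_fp16(args, StreamConfig{nullptr, false});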
gemm_bias_add_fp16.cpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <array>
#include <iostream>

#include "gemm_bias_add.hpp"

#include "ck/utility/env.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using AccDataType      = F32;
using CShuffleDataType = F32;

using ALayout  = Row;
using BLayout  = Col;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using CLayout  = Row;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add         = ck::tensor_operation::element_wise::Add;

using AElementOp   = PassThrough;
using BElementOp   = PassThrough;
using CDEElementOp = Add; // fuses the bias add: C = acc + D0

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// MNKPadding pads all three dimensions, so arbitrary M/N/K are supported.
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType>
using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
    Row, Col, DsLayout, CLayout, ADataType, BDataType,
    DsDataType, CDataType, AccDataType, CShuffleDataType,
    AElementOp, BElementOp, CDEElementOp, GemmSpec,
    64,                                  // BlockSize
    16, 16, 64,                          // MPerBlock, NPerBlock, KPerBlock
    8, 8,                                // AK1, BK1
    16, 16,                              // MPerXDL, NPerXDL
    1, 1,                                // MXdlPerWave, NXdlPerWave
    S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,  // A block transfer: cluster lengths, arrange order, src access order
    2, 8, 8, 0,                          // A: src vector dim, src scalars/vector, dst scalars/vector, LDS padding
    S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,  // B block transfer: cluster lengths, arrange order, src access order
    2, 8, 8, 0,                          // B: src vector dim, src scalars/vector, dst scalars/vector, LDS padding
    1, 1,                                // CShuffle: M/N XdlPerWave per shuffle
    S<1, 16, 1, 4>, S<4, 4>,             // CDE shuffle: cluster lengths, scalars per vector (E, D0)
    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
// Fallback instance: identical tiling, but scalar (width-1) global loads and
// stores, so it accepts shapes whose strides break the 8-wide accesses above.
template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType>
using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
    Row, Col, DsLayout, CLayout, ADataType, BDataType,
    DsDataType, CDataType, AccDataType, CShuffleDataType,
    AElementOp, BElementOp, CDEElementOp, GemmSpec,
    64,                                  // BlockSize
    16, 16, 64,                          // MPerBlock, NPerBlock, KPerBlock
    8, 8,                                // AK1, BK1
    16, 16,                              // MPerXDL, NPerXDL
    1, 1,                                // MXdlPerWave, NXdlPerWave
    S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,  // A block transfer: cluster lengths, arrange order, src access order
    2, 1, 1, 0,                          // A: src vector dim, src scalars/vector, dst scalars/vector, LDS padding
    S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,  // B block transfer: cluster lengths, arrange order, src access order
    2, 1, 1, 0,                          // B: src vector dim, src scalars/vector, dst scalars/vector, LDS padding
    1, 1,                                // CShuffle: M/N XdlPerWave per shuffle
    S<1, 16, 1, 4>, S<1, 1>,             // CDE shuffle: cluster lengths, scalars per vector (E, D0)
    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
// clang-format on
float gemm_bias_add_fp16(const GemmBiasAddArgs& args, const StreamConfig& config)
{
    using ADataType  = F16;
    using BDataType  = F16;
    using CDataType  = F16;
    using D0DataType = F16;
    using DsDataType = ck::Tuple<D0DataType>;

    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        std::cout << "gemm_bias_add_fp16: {"
                  << "mat_a: " << args.mat_a << ", mat_b: " << args.mat_b
                  << ", mat_bias: " << args.mat_bias << ", mat_c: " << args.mat_c
                  << ", M: " << args.M << ", N: " << args.N << ", K: " << args.K << "}"
                  << std::endl;
    }

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};

    // A is row-major M x K and B is column-major K x N, so both have stride K.
    // The bias has stride 0, which broadcasts the same N values to every row.
    ck::index_t StrideA = args.K;
    ck::index_t StrideB = args.K;
    ck::index_t StrideD = 0;
    ck::index_t StrideC = args.N;

    constexpr ck::index_t NumDTensor = DsDataType::Size();

    float ave_time = 0;

    auto Run = [&](auto& gemm) {
        auto argument = gemm.MakeArgument(args.mat_a,
                                          args.mat_b,
                                          std::array<const void*, NumDTensor>{args.mat_bias},
                                          args.mat_c,
                                          args.M,
                                          args.N,
                                          args.K,
                                          StrideA,
                                          StrideB,
                                          std::array<ck::index_t, NumDTensor>{StrideD},
                                          StrideC,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
        if(!gemm.IsSupportedArgument(argument))
        {
            return false;
        }
        auto invoker = gemm.MakeInvoker();
        ave_time     = invoker.Run(argument, config);
        return true;
    };

    // Prefer the vectorized instance; fall back to the scalar one, and report
    // instead of failing silently if neither supports the problem.
    auto gemm = DeviceOpInstance_64_16_16_64<ADataType, BDataType, DsDataType, CDataType>{};
    if(!Run(gemm))
    {
        auto gemm_def = DeviceOpInstance_default<ADataType, BDataType, DsDataType, CDataType>{};
        if(!Run(gemm_def))
        {
            std::cerr << "gemm_bias_add_fp16: no supported device instance for this problem"
                      << std::endl;
        }
    }
    return ave_time;
}
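For callers that want the kernel on an explicit HIP stream, a minimal sketch (assumes a filled GemmBiasAddArgs `args` as above; error checking omitted):

#include <hip/hip_runtime.h>

hipStream_t stream;
hipStreamCreate(&stream);

// time_kernel = false: launch once on `stream`; the returned time is 0.
float t_ms = gemm_bias_add_fp16(args, StreamConfig{stream, false});

hipStreamSynchronize(stream);
hipStreamDestroy(stream);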
gemm_bias_add_xdl_fp16.cpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <type_traits>
#include <vector>

#include "gemm_bias_add.hpp"

#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using F16 = ck::half_t;

using A0DataType = F16;
using B0DataType = F16;
using D0DataType = F16;
using EDataType  = F16;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using A0Layout = Row;
using B0Layout = Col;
using D0Layout = Row;
using ELayout  = Row;
// Host reference: computes E[m][n] = sum_k A[m][k] * B[k][n] + bias[n] in fp32,
// with B stored column-major (so B[k][n] lives at mat_B[n * K + k]).
void RunUnfusedTest(const std::vector<ck::half_t>& mat_A,
                    const std::vector<ck::half_t>& mat_B,
                    const std::vector<ck::half_t>& mat_bias,
                    std::vector<ck::half_t>& mat_E,
                    int K,
                    int M,
                    int N)
{
    for(int m = 0; m < M; m++)
    {
        for(int n = 0; n < N; n++)
        {
            float psum = 0.f;
            for(int k = 0; k < K; k++)
            {
                float areg = ck::type_convert<float>(mat_A[m * K + k]);
                float breg = ck::type_convert<float>(mat_B[n * K + k]);
                psum += areg * breg;
            }
            psum += ck::type_convert<float>(mat_bias[n]);
            mat_E[m * N + n] = ck::type_convert<ck::half_t>(psum);
        }
    }
}
int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = true;

    // GEMM shape
    ck::index_t M = 512;
    ck::index_t N = 1024;
    ck::index_t K = 256;

    ck::index_t StrideA = K;
    ck::index_t StrideB = K;
    ck::index_t StrideD = 0;
    ck::index_t StrideE = N;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 11)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);

        StrideA = std::stoi(argv[7]);
        StrideB = std::stoi(argv[8]);
        StrideD = std::stoi(argv[9]);
        StrideE = std::stoi(argv[10]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=random values in [-0.5, 0.5], else alternate ranges)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to arg10: M, N, K, StrideA, StrideB, StrideD, StrideE\n");
        exit(0);
    }
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            using namespace ck::literals;

            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
    Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
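    // Illustration (not in the original): with the defaults M=512, N=1024, K=256,
    // the descriptors above come out as
    //   a0_m_k : dims {512, 256},  strides {256, 1}  (row-major,    StrideA = K)
    //   b0_k_n : dims {256, 1024}, strides {1, 256}  (column-major, StrideB = K)
    //   d0_m_n : dims {512, 1024}, strides {0, 1}    (StrideD = 0 broadcasts one
    //            row of N bias values to every m, so only N elements are stored)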
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
    DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
    DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());

    a0_device_buf.ToDevice(a0_m_k.mData.data());
    b0_device_buf.ToDevice(b0_k_n.mData.data());
    d0_device_buf.ToDevice(d0_m_n.mData.data());
    e_device_buf.SetZero();

    GemmBiasAddArgs gemm_args{a0_device_buf.GetDeviceBuffer(),
                              b0_device_buf.GetDeviceBuffer(),
                              d0_device_buf.GetDeviceBuffer(),
                              e_device_buf.GetDeviceBuffer(),
                              M,
                              N,
                              K};

    // StreamConfig fields: stream, time_kernel, log_level, cold_niters, nrepeat.
    float ave_time =
        gemm_bias_add_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 0, 20, 50});
    std::size_t flop      = std::size_t(2) * M * N * K;
    std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
                            sizeof(D0DataType) * N + sizeof(EDataType) * M * N;

    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;
    if(do_verification)
    {
        RunUnfusedTest(a0_m_k.mData, b0_k_n.mData, d0_m_n.mData, e_m_n_host_result.mData, K, M, N);

        e_device_buf.FromDevice(e_m_n_device_result.mData.data());

        return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
    }

    return 0;
}
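For the record, a plausible way to run the example from a build tree (target name taken from the CMakeLists above; the argument sets mirror the parsing in main):

./example_gemm_bias_add_xdl_fp16                                      (defaults: M=512, N=1024, K=256)
./example_gemm_bias_add_xdl_fp16 1 1 1                                (verification, init method, time kernel)
./example_gemm_bias_add_xdl_fp16 1 1 1 512 1024 256 256 256 0 1024    (plus M, N, K and the four strides)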