Commit f0759faf authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 20ddaeba 764164b4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = BF16;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Multiply;
using CDEElementOp = FastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{}));
Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 0;
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<A0DataType> a_m_k({M, K});
Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B1DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = BF16;
using DsDataType = ck::Tuple<B1DataType, D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<B1Layout, D0Layout>;
using ELayout = Row;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyAddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{}));
Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 1;
constexpr ck::index_t NumDTensor = 2;
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{b1_device_buf.GetDeviceBuffer(),
d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB},
std::array<ck::index_t, NumDTensor>{0, StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<A0DataType> a_m_k({M, K});
Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
#if 0
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
}
}
#endif
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(0, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
#include <set>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/flush_icache.hpp"
namespace ck {
namespace utility {
template <typename Argument>
struct RotatingMemWrapper
{
using ADataType = decltype(Argument::p_a_grid);
using BDataType = decltype(Argument::p_b_grid);
RotatingMemWrapper() = delete;
RotatingMemWrapper(Argument& arg_,
std::size_t rotating_count_,
std::size_t size_a_,
std::size_t size_b_)
: arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
{
p_a_grids.push_back(arg.p_a_grid);
p_b_grids.push_back(arg.p_b_grid);
for(size_t i = 1; i < rotating_count; i++)
{
{
void* pADeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]),
size_a_,
hipMemcpyDeviceToDevice));
p_a_grids.push_back(pADeviceBuf);
}
{
void* pBDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]),
size_b_,
hipMemcpyDeviceToDevice));
p_b_grids.push_back(pBDeviceBuf);
}
}
}
void Next()
{
if(rotating_count > 1)
{
std::size_t idx = iter++ % rotating_count;
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
}
}
void Print()
{
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
<< ", rotating_count: " << rotating_count << "}" << std::endl;
}
~RotatingMemWrapper()
{
if(rotating_count > 1)
{
// restore ptr
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
}
}
}
private:
Argument& arg;
std::size_t iter = 0;
std::size_t rotating_count = 1;
std::size_t size_a = 0;
std::size_t size_b = 0;
std::vector<const void*> p_a_grids;
std::vector<const void*> p_b_grids;
};
inline void flush_icache()
{
hipDeviceProp_t deviceProps;
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
hip_check_error(hipGetLastError());
}
// if TimePrePress == false, return time does not include preprocess's time
template <bool TimePreprocess, typename Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
PreProcessFunc preprocess,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args& args)
{
#if CK_TIME_KERNEL
#define MEDIAN 1
if(stream_config.time_kernel_)
{
#if DEBUG_LOG
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
__func__,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up %d times\n", stream_config.cold_niters_);
#endif
// warm up
for(int i = 0; i < stream_config.cold_niters_; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
hip_check_error(hipGetLastError());
}
const int nrepeat = stream_config.nrepeat_;
if(nrepeat == 0)
{
return 0.0;
}
#if DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
#if MEDIAN
std::set<float> times;
#else
float total_time = 0;
#endif
for(int i = 0; i < nrepeat; ++i)
{
if constexpr(!TimePreprocess)
{
preprocess();
}
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
// calculate preprocess time
if constexpr(TimePreprocess)
{
preprocess();
}
// run real kernel
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
hip_check_error(hipGetLastError());
// end real kernel
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float cur_time = 0;
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
times.insert(cur_time);
#else
total_time += cur_time;
#endif
#if DEBUG_LOG
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
printf("args.p_a_grid: %p, args.p_b_grid:%p\n",
static_cast<const void*>(args.p_a_grid),
static_cast<const void*>(args.p_b_grid));
#endif
}
#if MEDIAN
auto mid = times.begin();
std::advance(mid, (nrepeat - 1) / 2);
if(nrepeat % 2 == 1)
{
return *mid;
}
else
{
auto mid_next = mid;
std::advance(mid_next, 1);
return (*mid + *mid_next) / 2;
}
#else
return total_time / nrepeat;
#endif
}
else
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
hip_check_error(hipGetLastError());
return 0;
}
#else
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
hip_check_error(hipGetLastError());
return 0;
#endif
}
} // namespace utility
} // namespace ck
......@@ -13,4 +13,7 @@ struct StreamConfig
int log_level_ = 0;
int cold_niters_ = 5;
int nrepeat_ = 50;
bool flush_cache = false;
int rotating_count = 1;
};
......@@ -140,8 +140,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride;
using Base::BMmaKStride;
static constexpr index_t WgpPerCU =
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / (4 * warpSize / BlockSize),
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2
......@@ -631,8 +633,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / (4 * warpSize / BlockSize),
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2
......
......@@ -184,19 +184,22 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr auto ds_read_b_issue_cycle =
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / ds_read_a_issue_cycle;
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / ds_read_b_issue_cycle;
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
// Separate this part?
constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
sizeof(ComputeDataType) / sizeof(BDataType)
? sizeof(ComputeDataType) / sizeof(ADataType)
: sizeof(ComputeDataType) / sizeof(BDataType);
constexpr auto num_mfma_stage1 =
num_mfma_inst - num_mfma_per_ds_read * (num_ds_read_inst_a / ds_read_a_mfma_rate +
num_ds_read_inst_b / ds_read_b_mfma_rate);
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
......@@ -226,16 +229,36 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
});
// stage 2
static_for<0, num_ds_read_inst_a / ds_read_a_mfma_rate, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_ds_read_inst_b / ds_read_b_mfma_rate, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
......
......@@ -194,9 +194,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr auto ds_read_b_issue_cycle =
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
......
......@@ -41,7 +41,8 @@ template <typename ThreadGroup,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags>
typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r2
{
static constexpr index_t nDim =
......@@ -100,7 +101,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
......@@ -117,29 +118,33 @@ struct ThreadGroupTensorSliceTransfer_v7r2
}
}
template <typename SrcBuffers>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs);
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename DstBuffers>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs);
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs));
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
}
}
......@@ -206,7 +211,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
SrcScalarPerVector,
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags>;
ThreadTransferDstResetCoordinateAfterRunFlags,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmTileLoopKernelArguments
{
__host__ __device__
GroupedGemmTileLoopKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -627,7 +627,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
const bool valid_a_access_dim =
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
if(!(valid_a_vector_size && valid_a_access_dim))
{
return false;
......@@ -637,7 +638,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
const bool valid_b_access_dim =
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
if(!(valid_b_vector_size && valid_b_access_dim))
{
return false;
......@@ -648,7 +650,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_d_vector_size =
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
const bool valid_d_access_dim =
arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_d_vector_size && valid_d_access_dim))
{
valid_ds_access = false;
......@@ -662,7 +665,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
const bool valid_e_access_dim =
arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_e_vector_size && valid_e_access_dim))
{
return false;
......
......@@ -92,15 +92,6 @@ struct Add
};
};
struct Scales
{
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = ck::type_convert<Y>(ck::type_convert<float>(x0) * ck::type_convert<float>(x1));
}
};
struct Max
{
template <typename Y, typename X0, typename X1>
......@@ -188,6 +179,16 @@ struct Multiply
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp * x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
......@@ -521,6 +522,71 @@ struct AddFastGelu
}
};
// E = MultiplyFastGelu(C + D)
struct MultiplyFastGelu
{
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d) const
{
const float x = c * d;
FastGelu{}.template operator()<float, float>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
{
const half_t x = c * d;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
{
const float x0_f = c * d;
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
{
const float x0_f = type_convert<float>(c) * type_convert<float>(d);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
{
const float x0_f = c * type_convert<float>(d);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
};
// E = Silu(C + D)
struct AddSilu
{
......
......@@ -221,6 +221,15 @@ struct MultiplyAdd
e = y;
}
template <>
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
const float& c,
const bhalf_t& d0,
const bhalf_t& d1) const
{
const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
......@@ -240,6 +249,26 @@ struct MultiplyAdd
}
};
struct MultiplyAddFastGelu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(
ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const
{
const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = ck::type_convert<ck::bhalf_t>(x1_f);
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
......@@ -499,6 +528,26 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
}
};
struct ConvInvscale
{
/// @brief Op to multiply convolution results by inverted scale factors
/// @param e Output after scaling
/// @param c Convolution result
/// @param d0 Input scale factor
/// @param d1 Weights scale factor
/// @param d2 Output scale factor
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
template <>
__host__ __device__ void operator()<f8_t, float, float, float, float>(
f8_t& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
e = type_convert<f8_t>(c / d0 / d1 / d2);
};
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -504,6 +504,16 @@ struct FastGelu
y = type_convert<half_t>(y_f);
}
template <>
__host__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<bhalf_t>(y_f);
}
template <>
__device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -151,7 +151,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
......@@ -275,7 +275,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
......@@ -428,7 +428,7 @@ struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
......@@ -900,6 +900,11 @@ struct OffsettedBlockToCTileMap
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{
return block_to_ctile_map_.CalculateGridSize(M, N);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
};
......
......@@ -594,11 +594,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
#if 0
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
......@@ -616,7 +611,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
uniform_sequence_gen_t<NumATensor, false>,
uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>>{as_grid_desc_ak0_m_ak1,
idx_as_block_begin,
tie(a_block_desc_ak0_m_ak1),
......@@ -627,11 +622,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
#if 0
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
......@@ -649,7 +639,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
uniform_sequence_gen_t<NumBTensor, false>,
uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>>{bs_grid_desc_bk0_n_bk1,
idx_bs_block_begin,
tie(b_block_desc_bk0_n_bk1),
......
......@@ -257,7 +257,70 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
e_grid_desc_m_n);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename ALayout, typename BLayout, typename ELayout>
__host__ __device__ static bool
CheckTensorTransfersValidity(index_t MRaw, index_t NRaw, index_t KRaw)
{
// Check if the vector dim is K1 or M|N
const auto A_vector_dim_size = ABlockTransferSrcVectorDim == 2 ? KRaw : MRaw;
const auto B_vector_dim_size = BBlockTransferSrcVectorDim == 2 ? KRaw : NRaw;
const auto E_vector_dim_size = NRaw;
// check vector load for A tensor
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
if(!(A_vector_dim_size == KRaw &&
A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
return false;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
if(!(A_vector_dim_size == MRaw &&
A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
return false;
}
else
{
return false;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
if(!(B_vector_dim_size == NRaw &&
B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
return false;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if(!(B_vector_dim_size == KRaw &&
B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
return false;
}
else
{
return false;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELayout>)
{
if(!(E_vector_dim_size == NRaw &&
E_vector_dim_size % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
return false;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELayout>)
{
if(!(E_vector_dim_size == NRaw &&
CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
return false;
}
else
{
return false;
}
return true;
}
template <typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
......@@ -267,7 +330,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap&)
[[maybe_unused]] const Block2ETileMap&)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
......@@ -285,7 +348,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
......@@ -306,7 +368,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// check gridwise gemm pipeline
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
......@@ -938,6 +999,63 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
template <bool HasMainKBlockLoop,
typename AGridDesc_MK,
typename BGridDesc_NK,
typename DsGridDesc_MN,
typename EGridDesc_MN,
typename Block2ETileMap>
__device__ static void Run(const void* __restrict__ p_a_grid_,
const void* __restrict__ p_b_grid_,
DsGridPointer p_ds_grid,
void* __restrict__ p_e_grid_,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_MK& a_grid_desc_m_k,
const BGridDesc_NK& b_grid_desc_n_k,
const DsGridDesc_MN& ds_grid_desc_m_n,
const EGridDesc_MN& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_MN{}))>;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
});
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment