Unverified Commit 12235112 authored by rocking5566, committed by GitHub

external api for gemm + layernorm (#285)

* Extract base class for elementwise

* Refactor interface of DeviceGemmReduce. Do not use tuple in interface

* [What] Rename d to reduce in gemm + reduction related code
  [Why] Prepare to add a d term for add

* Unify base class of gemm + reduce and gemm + bias + add + reduce

* 1. Rename gemm_bias_add_reduce for external api
 2. Refine cmake

* Add normalize device operation

* [What] Reorder the arguments
  [Why] d0 is also an input of c.

* Add type string

* Add example of gemm_bias_add_layernorm via external api

* Refactor example code

* clang-format

* Fix compile error

* clang-format

* Add external api for gemm_add_add_layernorm and normalize

* Add client example

* clang-format
parent aebd211c
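# Client example: links against the prebuilt CK device-operation instance library
# (assumes the composable_kernel package exports this target).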
add_executable(gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp)
target_link_libraries(gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp"
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using BiasDataType = F32;
using CDataType = F16;
using D0DataType = F16;
using ReduceDataType = F32;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
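// Minimal RAII wrapper over hipMalloc/hipFree for the device buffers used below.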
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
template <typename gemm_reduce_op_ptr>
bool RunDeviceGemmMeanSquareMean(gemm_reduce_op_ptr& p_op,
const void* p_a,
const void* p_b,
const void* p_bias,
const void* p_d0,
void* p_c,
void* p_mean,
void* p_square_mean,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int StrideD0,
bool time_kernel)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
auto passOp = PassThrough{};
auto squareOp = UnarySquareElementOp{};
auto divOp = UnaryDivElementOp{N};
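    // Each row of c is reduced twice: once through PassThrough (sum of c) and once through
    // UnarySquare (sum of c^2); dividing both results by N yields mean(c) and mean(c^2).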
auto argument_ptr =
p_op->MakeArgumentPointer(p_a,
p_b,
p_bias,
{p_d0},
p_c,
{p_mean, p_square_mean},
M,
N,
K,
StrideA,
StrideB,
StrideC,
{StrideD0},
{&passOp, &passOp, &passOp}, // functor for a, b, c
{&passOp}, // functor for d0
{&passOp, &squareOp}, // functor for inputs of reduction
{&divOp, &divOp}); // functor for outputs of reduction
if(p_op->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = p_op->MakeInvokerPointer();
        // Note: when timing gemm + reduce, the reduction output may be wrong, because the
        // reduction tensors must be reinitialized before each launch; with time_kernel = true
        // the kernel is run many times without reinitializing them.
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
std::cout << "Gemm + reduce Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
return true;
}
return false;
}
template <typename normalize_op_ptr>
bool RunDeviceNormalize2D(normalize_op_ptr& p_op,
const void* p_x,
const void* p_mean,
const void* p_square_mean,
const void* p_gamma,
const void* p_beta,
void* p_y,
int M,
int N,
int StrideX,
bool time_kernel)
{
std::array<const void*, 5> input = {p_x, p_mean, p_square_mean, p_gamma, p_beta};
std::array<void*, 1> output = {p_y};
auto normalize_functor = ck::tensor_operation::element_wise::Normalize{};
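    // Input order is {x, mean, square_mean, gamma, beta}. The per-row statistics use strides
    // {1, 0} so they broadcast along N; the per-column gamma/beta use strides {0, 1} so they
    // broadcast along M.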
    auto argument_ptr = p_op->MakeArgumentPointer(input,
                                                  output,
                                                  {M, N},
                                                  {{StrideX, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
                                                  {{StrideX, 1}},
                                                  normalize_functor);
if(p_op->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = p_op->MakeInvokerPointer();
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
std::cout << "Normalize Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
return true;
}
return false;
}
int main()
{
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024;
ck::index_t StrideD0 = 1024;
const auto gemm_reduce_ptrs = ck::tensor_operation::device::device_gemm_instance::
get_device_gemm_add_add_mean_squaremean_instances<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout>();
const auto normalize_ptrs =
ck::tensor_operation::device::get_device_normalize_from_mean_meansquare_instances<
CDataType,
ReduceDataType,
ReduceDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>();
std::cout << "found " << gemm_reduce_ptrs.size()
<< " gemm_reduceMean_reduceSquareMean instances" << std::endl;
std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
SimpleDeviceMem bias_device_buf(sizeof(BiasDataType) * N);
SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
f_matrix_space_size(M, N, StrideD0, CLayout{}));
SimpleDeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * M);
SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * M);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
SimpleDeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * M * N);
bool b_time_kernel = true;
bool b_only_run_first_kernel = true;
// layernorm => (1) + (2)
// (1). c = gemm(a, b), reduce_mean(c), reduce_square_mean(c)
// (2). normalize(c, mean, square_mean, gamma, beta)
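    // Stage (2) then computes, per row (a sketch; the exact epsilon handling lives inside
    // ck::tensor_operation::element_wise::Normalize):
    //   variance = square_mean - mean * mean
    //   y        = gamma * (c - mean) / sqrt(variance + epsilon) + beta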
for(auto& gemm_reduce_ptr : gemm_reduce_ptrs)
{
// run first available kernel
if(RunDeviceGemmMeanSquareMean(gemm_reduce_ptr,
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(),
d0_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideD0,
b_time_kernel))
{
if(b_only_run_first_kernel)
break;
}
else
{
std::cout << gemm_reduce_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
for(auto& normalize_ptr : normalize_ptrs)
{
if(RunDeviceNormalize2D(normalize_ptr,
c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
layerNorm_device_buf.GetDeviceBuffer(),
M,
N,
StrideC,
b_time_kernel))
{
if(b_only_run_first_kernel)
break;
}
else
{
std::cout << normalize_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
@@ -7,3 +7,4 @@ find_package(hip REQUIRED PATHS /opt/rocm)
 message(STATUS "Build with HIP ${hip_VERSION}")
 add_subdirectory(02_gemm_add_add_fastgelu)
+add_subdirectory(03_gemm_layernorm)
@@ -33,19 +33,19 @@ using BDataType = F16;
 using CDataType = F16;
 using GemmAccDataType = F32;
 using ReduceAccDataType = F32;
-using DDataType = F64;
-using DPtrsGlobal = ck::Tuple<DDataType*>;
+using ReduceDataType = F64;
+using ReducePtrsGlobal = ck::Tuple<ReduceDataType*>;
 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
 using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;
-using DsReduceOp = ck::Tuple<ck::reduce::Max>;
-using DsElementOp = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
-using DGlobalMemOp =
+using ReduceOps = ck::Tuple<ck::reduce::Max>;
+using ReduceElementOps = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
+using ReduceGlobalMemOps =
     ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
 static constexpr auto GemmSpecialization =
@@ -53,11 +53,11 @@ static constexpr auto GemmSpecialization =
 // clang-format off
 using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
-//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
-//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
-//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
+< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceElementOps, ReduceElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -68,12 +68,12 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
                                                                         BElementOp,
                                                                         CElementOp>;
-template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
+template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
 void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
 {
     std::size_t gemm_flop = std::size_t(2) * M * N * K;
     std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
-                                sizeof(CDataType) * M * N + sizeof(DDataType) * M;
+                                sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M;
     float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
     float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
@@ -148,17 +148,17 @@ int main(int argc, char* argv[])
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<DDataType> d_m_host_result(
+    Tensor<ReduceDataType> reduce_m_host_result(
         HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
     Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<DDataType> d_m_device_result(
+    Tensor<ReduceDataType> reduce_m_device_result(
         HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
-    std::cout << "d_m: " << d_m_host_result.mDesc << std::endl;
+    std::cout << "reduce_m: " << reduce_m_host_result.mDesc << std::endl;
     switch(init_method)
     {
@@ -176,35 +176,40 @@ int main(int argc, char* argv[])
     DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
     DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
-    DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace());
+    DeviceMem reduce_device_buf(sizeof(ReduceDataType) *
+                                reduce_m_device_result.mDesc.GetElementSpace());
     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
     auto c_element_op = CElementOp{};
-    auto ds_element_op = DsElementOp{};
-    auto p_ds_global = ck::make_tuple(static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()));
+    auto reduce_element_op = ReduceElementOps{}[ck::Number<0>{}];
+    std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
+    std::array<void*, 1> reduce_element_ops = {&reduce_element_op};
+    std::array<void*, 1> p_reduces = {reduce_device_buf.GetDeviceBuffer()};
     // do GEMM
     auto gemm = DeviceGemmReduceInstance{};
     auto invoker = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                      static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                      static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                      p_ds_global,
+    auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                                      b_device_buf.GetDeviceBuffer(),
+                                      nullptr,
+                                      {},
+                                      c_device_buf.GetDeviceBuffer(),
+                                      p_reduces,
                                       M,
                                       N,
                                       K,
                                       StrideA,
                                       StrideB,
                                       StrideC,
-                                      a_element_op,
-                                      b_element_op,
-                                      c_element_op,
-                                      ds_element_op,
-                                      ds_element_op);
+                                      {},
+                                      gemm_element_ops,
+                                      {},
+                                      reduce_element_ops,
+                                      reduce_element_ops);
     if(!gemm.IsSupportedArgument(argument))
     {
@@ -215,7 +220,7 @@ int main(int argc, char* argv[])
     // [CAUSION]: launch_and_time_kernel will not initialize D.
     // If we evaluate kernel multiple time but without initialize D. Verification will fail
-    d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest());
+    reduce_device_buf.SetValue(ck::NumericLimits<ReduceDataType>::Lowest());
     invoker.Run(argument, StreamConfig{nullptr, false});
     bool pass = true;
@@ -223,7 +228,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
         c_device_buf.FromDevice(c_m_n_device_result.mData.data());
-        d_device_buf.FromDevice(d_m_device_result.mData.data());
+        reduce_device_buf.FromDevice(reduce_m_device_result.mData.data());
         auto ref_gemm = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
@@ -233,27 +238,27 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);
-        auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}];
+        auto reduce_op = ReduceOps{}[ck::Number<0>{}];
         for(int m = 0; m < M; ++m)
         {
-            ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue<ReduceAccDataType>();
+            ReduceAccDataType reduce_acc = reduce_op.GetIdentityValue<ReduceAccDataType>();
             for(int n = 0; n < N; ++n)
             {
                 ReduceAccDataType curr_val =
                     ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
-                d_reduce_op(d_acc, curr_val);
+                reduce_op(reduce_acc, curr_val);
             };
-            d_m_host_result(m) = d_acc;
+            reduce_m_host_result(m) = reduce_acc;
         }
         pass = ck::utils::check_err(c_m_n_device_result.mData,
                                     c_m_n_host_result.mData,
                                     "Error: Incorrect results c") &&
-               ck::utils::check_err(d_m_device_result.mData,
-                                    d_m_host_result.mData,
+               ck::utils::check_err(reduce_m_device_result.mData,
+                                    reduce_m_host_result.mData,
                                     "Error: Incorrect results d",
                                     1e-3,
                                     1e-3);
@@ -263,7 +268,7 @@ int main(int argc, char* argv[])
     {
         float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
-        DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(
+        DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(
             gemm_reduceMax_ave_time, M, N, K);
     }
...
@@ -33,27 +33,27 @@ using BDataType = F16;
 using CDataType = F16;
 using GemmAccDataType = F32;
 using ReduceAccDataType = F32;
-using DDataType = F32;
-using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
+using ReduceDataType = F32;
+using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
 using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;
-using D0ReduceOp = ck::reduce::Add;
-using D1ReduceOp = ck::reduce::Add;
-using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
+using ReduceOp0 = ck::reduce::Add;
+using ReduceOp1 = ck::reduce::Add;
+using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
 using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
 using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
 using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
-using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
-using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
-using DGlobalMemOp =
+using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
+using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
+using ReduceGlobalMemOps =
     ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                           ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -62,11 +62,11 @@ static constexpr auto GemmSpecialization =
 // clang-format off
 using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
-//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
-//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
-//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
+< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -77,13 +77,13 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
                                                                         BElementOp,
                                                                         CElementOp>;
-template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
+template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
 void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
 {
     std::size_t gemm_flop = std::size_t(2) * M * N * K;
     std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
-                                sizeof(CDataType) * M * N + sizeof(DDataType) * M +
-                                sizeof(DDataType) * M;
+                                sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
+                                sizeof(ReduceDataType) * M;
     float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
     float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
@@ -158,22 +158,22 @@ int main(int argc, char* argv[])
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<DDataType> d0_m_host_result(
+    Tensor<ReduceDataType> reduce0_m_host_result(
         HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
-    Tensor<DDataType> d1_m_host_result(
+    Tensor<ReduceDataType> reduce1_m_host_result(
         HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
     Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<DDataType> d0_m_device_result(
+    Tensor<ReduceDataType> reduce0_m_device_result(
         HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
-    Tensor<DDataType> d1_m_device_result(
+    Tensor<ReduceDataType> reduce1_m_device_result(
         HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
-    std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl;
-    std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl;
+    std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl;
+    std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl;
     switch(init_method)
     {
@@ -191,39 +191,48 @@ int main(int argc, char* argv[])
     DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
     DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
-    DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
-    DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());
+    DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
+                                 reduce0_m_device_result.mDesc.GetElementSpace());
+    DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
+                                 reduce1_m_device_result.mDesc.GetElementSpace());
     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
     auto c_element_op = CElementOp{};
-    auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
-                                     static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
-    auto dxs_in_element_op = DxsInElementOps{};
-    auto dxs_out_element_op = DxsOutElementOps{N, N};
+    std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
+    auto passthrough = UnaryIdenticElementOp{};
+    auto square = UnarySquareElementOp{};
+    auto div = UnaryDivElementOp{N};
+    std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
+    std::array<void*, 2> reduce_out_element_ops = {&div, &div};
+    std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
+                                      reduce1_device_buf.GetDeviceBuffer()};
     // do GEMM
     auto gemm = DeviceGemmReduceInstance{};
     auto invoker = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                      static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                      static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                      dxs_global,
+    auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                                      b_device_buf.GetDeviceBuffer(),
+                                      nullptr,
+                                      {},
+                                      c_device_buf.GetDeviceBuffer(),
+                                      p_reduces,
                                       M,
                                       N,
                                       K,
                                       StrideA,
                                       StrideB,
                                       StrideC,
-                                      a_element_op,
-                                      b_element_op,
-                                      c_element_op,
-                                      dxs_in_element_op,
-                                      dxs_out_element_op);
+                                      {},
+                                      gemm_element_ops,
+                                      {},
+                                      reduce_in_element_ops,
+                                      reduce_out_element_ops);
     if(!gemm.IsSupportedArgument(argument))
     {
@@ -232,9 +241,9 @@ int main(int argc, char* argv[])
                     "not support this GEMM problem");
     }
-    // init DO, D1 to 0
-    d0_device_buf.SetZero();
-    d1_device_buf.SetZero();
+    // init reducetion buffer to 0
+    reduce0_device_buf.SetZero();
+    reduce1_device_buf.SetZero();
     // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
     // will not be correct. need to set time_kernel = false for correctness test
@@ -244,8 +253,8 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
         c_device_buf.FromDevice(c_m_n_device_result.mData.data());
-        d0_device_buf.FromDevice(d0_m_device_result.mData.data());
-        d1_device_buf.FromDevice(d1_m_device_result.mData.data());
+        reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
+        reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
         auto ref_gemm = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
@@ -255,42 +264,40 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);
-        auto d0_reduce_op = D0ReduceOp{};
-        auto d1_reduce_op = D1ReduceOp{};
+        auto reduce0_op = ReduceOp0{};
+        auto reduce1_op = ReduceOp1{};
         for(int m = 0; m < M; ++m)
         {
-            auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
-            auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
+            auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
+            auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
             for(int n = 0; n < N; ++n)
            {
                 auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
-                ReduceAccDataType d0_val;
-                ReduceAccDataType d1_val;
-                dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
-                dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
-                d0_reduce_op(d0_acc, d0_val);
-                d1_reduce_op(d1_acc, d1_val);
+                ReduceAccDataType square_c_val;
+                square(square_c_val, c_val);
+                reduce0_op(reduce0_acc, c_val);
+                reduce1_op(reduce1_acc, square_c_val);
             }
-            dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
-            dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
-            d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
-            d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
+            div(reduce0_acc, reduce0_acc);
+            div(reduce1_acc, reduce1_acc);
+            reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
+            reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
        }
         pass = ck::utils::check_err(c_m_n_device_result.mData,
                                     c_m_n_host_result.mData,
                                     "Error: Incorrect results c") &&
-               ck::utils::check_err(d0_m_device_result.mData,
-                                    d0_m_host_result.mData,
+               ck::utils::check_err(reduce0_m_device_result.mData,
+                                    reduce0_m_host_result.mData,
                                     "Error: Incorrect results d0",
                                     1e-4,
                                     1e-5) &&
-               ck::utils::check_err(d1_m_device_result.mData,
-                                    d1_m_host_result.mData,
+               ck::utils::check_err(reduce1_m_device_result.mData,
+                                    reduce1_m_host_result.mData,
                                     "Error: Incorrect results d1",
                                     1e-3,
                                     1e-5);
@@ -300,7 +307,7 @@ int main(int argc, char* argv[])
     {
         float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
-        DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(ave_time, M, N, K);
+        DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(ave_time, M, N, K);
     }
     return pass ? 0 : 1;
...
@@ -31,26 +31,26 @@ using ADataType = F16;
 using BDataType = F16;
 using CDataType = F16;
 using ReduceAccDataType = F32;
-using DDataType = F32;
-using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
+using ReduceDataType = F32;
+using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
 using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;
-using D0ReduceOp = ck::reduce::Add;
-using D1ReduceOp = ck::reduce::Add;
-using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
+using ReduceOp0 = ck::reduce::Add;
+using ReduceOp1 = ck::reduce::Add;
+using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
 using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
 using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
-using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
-using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
-using DGlobalMemOp =
+using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
+using ReduceOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
+using ReduceGlobalMemOps =
     ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                           ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -63,7 +63,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
 //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
 //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
+< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
 // clang-format on
 using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
@@ -143,16 +143,16 @@ int main(int argc, char* argv[])
     Tensor<CDataType> c_g_m_n_host_result(
         f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
-    Tensor<DDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
+    Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
         {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
-    Tensor<DDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
+    Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
         {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
     Tensor<CDataType> c_g_m_n_device_result(
         f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
-    Tensor<DDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
+    Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
         {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
-    Tensor<DDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
+    Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
         {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
@@ -177,38 +177,48 @@ int main(int argc, char* argv[])
     DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
     DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
-    DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace());
-    DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace());
+    DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
+                                 d0_g_m_device_result.mDesc.GetElementSpace());
+    DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
+                                 d1_g_m_device_result.mDesc.GetElementSpace());
     a_device_buf.ToDevice(a_g_m_k.mData.data());
     b_device_buf.ToDevice(b_g_k_n.mData.data());
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
     auto c_element_op = CElementOp{};
-    auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
-                                     static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
+    std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
+    auto passthrough = UnaryIdenticElementOp{};
+    auto square = UnarySquareElementOp{};
+    std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
+    std::array<void*, 2> reduce_out_element_ops = {&passthrough, &passthrough};
+    std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
+                                      reduce1_device_buf.GetDeviceBuffer()};
     // do GEMM
     auto batched_gemm = DeviceBatchedGemmReduceInstance{};
     auto invoker = batched_gemm.MakeInvoker();
-    auto argument =
-        batched_gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                  static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                  static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                  dxs_global,
-                                  M,
-                                  N,
-                                  K,
-                                  StrideA,
-                                  StrideB,
-                                  StrideC,
-                                  a_element_op,
-                                  b_element_op,
-                                  c_element_op,
-                                  DxsInElementOps{},
-                                  DxsOutElementOps{},
-                                  BatchCount);
+    auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                                              b_device_buf.GetDeviceBuffer(),
+                                              nullptr,
+                                              {},
+                                              c_device_buf.GetDeviceBuffer(),
+                                              p_reduces,
+                                              M,
+                                              N,
+                                              K,
+                                              StrideA,
+                                              StrideB,
+                                              StrideC,
+                                              {},
+                                              gemm_element_ops,
+                                              {},
+                                              reduce_in_element_ops,
+                                              reduce_out_element_ops,
+                                              BatchCount);
     if(!batched_gemm.IsSupportedArgument(argument))
     {
@@ -218,8 +228,8 @@ int main(int argc, char* argv[])
     }
     // init DO, D1 to 0
-    d0_device_buf.SetZero();
-    d1_device_buf.SetZero();
+    reduce0_device_buf.SetZero();
+    reduce1_device_buf.SetZero();
     // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
     // will not be correct. need to set time_kernel = false for correctness test
@@ -241,8 +251,8 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
         c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
-        d0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
-        d1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
+        reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
+        reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
         auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
         auto ref_invoker = ref_batched_gemm.MakeInvoker();
@@ -252,15 +262,15 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);
-        auto d0_reduce_op = D0ReduceOp{};
-        auto d1_reduce_op = D1ReduceOp{};
+        auto reduce0_op = ReduceOp0{};
+        auto reduce1_op = ReduceOp1{};
         for(int batch = 0; batch < BatchCount; ++batch)
         {
             for(int m = 0; m < M; ++m)
             {
-                auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
-                auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
+                auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
+                auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
                 for(int n = 0; n < N; ++n)
                 {
@@ -271,12 +281,12 @@ int main(int argc, char* argv[])
                     UnaryIdenticElementOp{}(d0_val, c_val);
                     UnarySquareElementOp{}(d1_val, c_val);
-                    d0_reduce_op(d0_acc, d0_val);
-                    d1_reduce_op(d1_acc, d1_val);
+                    reduce0_op(reduce0_acc, d0_val);
+                    reduce1_op(reduce1_acc, d1_val);
                 }
-                d0_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d0_acc);
-                d1_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d1_acc);
+                d0_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce0_acc);
+                d1_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce1_acc);
             }
         }
...
@@ -99,15 +99,17 @@ int main()
     a_m_n_device_buf.ToDevice(a_m_n.mData.data());
     b_n_device_buf.ToDevice(b_n.mData.data());
+    std::array<const void*, 2> input = {a_m_n_device_buf.GetDeviceBuffer(),
+                                        b_n_device_buf.GetDeviceBuffer()};
+    std::array<void*, 1> output = {c_m_n_device_buf.GetDeviceBuffer()};
+    std::vector<ck::index_t> a_strides = {Stride, 1};
+    std::vector<ck::index_t> b_strides = {0, 1};
+    std::vector<ck::index_t> c_strides = {Stride, 1};
     auto broadcastAdd = DeviceElementwiseAddInstance{};
-    auto argument = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(),
-                                                     b_n_device_buf.GetDeviceBuffer(),
-                                                     c_m_n_device_buf.GetDeviceBuffer(),
-                                                     {M, N},
-                                                     {Stride, 1},
-                                                     {0, 1}, // broadcast in first dimension
-                                                     {Stride, 1},
-                                                     Add{});
+    auto argument = broadcastAdd.MakeArgumentPointer(
+        input, output, {M, N}, {a_strides, b_strides}, {c_strides}, Add{});
     if(!broadcastAdd.IsSupportedArgument(argument.get()))
     {
...
@@ -81,18 +81,24 @@ int main()
     a_m_device_buf.ToDevice(a_m.mData.data());
     b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
+    std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
+                                        b_m_n_k_device_buf.GetDeviceBuffer()};
+    std::array<void*, 1> output = {c_m_n_k_device_buf.GetDeviceBuffer()};
+    std::vector<ck::index_t> a_strides = {1, 0, 0};
+    std::vector<ck::index_t> b_strides{b_m_n_k.mDesc.GetStrides().begin(),
+                                       b_m_n_k.mDesc.GetStrides().end()};
+    std::vector<ck::index_t> c_strides{c_m_n_k.mDesc.GetStrides().begin(),
+                                       c_m_n_k.mDesc.GetStrides().end()};
     auto broadcastAdd = DeviceElementwiseAddInstance{};
-    auto argument = broadcastAdd.MakeArgumentPointer(
-        a_m_device_buf.GetDeviceBuffer(),
-        b_m_n_k_device_buf.GetDeviceBuffer(),
-        c_m_n_k_device_buf.GetDeviceBuffer(),
-        std::vector<ck::index_t>{mnk.begin(), mnk.end()},
-        {1, 0, 0}, // broadcast A on second and third dimension
-        std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(),
-                                 b_m_n_k.mDesc.GetStrides().end()},
-        std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
-                                 c_m_n_k.mDesc.GetStrides().end()},
-        Add{});
+    auto argument =
+        broadcastAdd.MakeArgumentPointer(input,
+                                         output,
+                                         std::vector<ck::index_t>{mnk.begin(), mnk.end()},
+                                         {a_strides, b_strides},
+                                         {c_strides},
+                                         Add{});
     if(!broadcastAdd.IsSupportedArgument(argument.get()))
     {
...
...@@ -79,15 +79,17 @@ int main() ...@@ -79,15 +79,17 @@ int main()
a_m_device_buf.ToDevice(a_m.mData.data()); a_m_device_buf.ToDevice(a_m.mData.data());
b_m_device_buf.ToDevice(b_m.mData.data()); b_m_device_buf.ToDevice(b_m.mData.data());
std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {1};
std::vector<ck::index_t> b_strides = {1};
std::vector<ck::index_t> c_strides = {1};
auto broadcastAdd = DeviceElementwiseAddInstance{}; auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(), auto argument = broadcastAdd.MakeArgumentPointer(
b_m_device_buf.GetDeviceBuffer(), input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});
c_m_device_buf.GetDeviceBuffer(),
{M},
{1},
{1},
{1},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
...
@@ -81,16 +81,22 @@ int main()
    a_device_buf.ToDevice(a.mData.data());
    b_device_buf.ToDevice(b.mData.data());

    std::array<const void*, 2> input = {a_device_buf.GetDeviceBuffer(),
                                        b_device_buf.GetDeviceBuffer()};
    std::array<void*, 1> output = {c_device_buf.GetDeviceBuffer()};

    std::vector<ck::index_t> a_strides{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()};
    std::vector<ck::index_t> b_strides{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()};
    std::vector<ck::index_t> c_strides{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()};

    auto broadcastAdd = DeviceElementwiseAddInstance{};
    auto argument =
        broadcastAdd.MakeArgumentPointer(input,
                                         output,
                                         std::vector<ck::index_t>{nchw.begin(), nchw.end()},
                                         {a_strides, b_strides},
                                         {c_strides},
                                         Add{});

    if(!broadcastAdd.IsSupportedArgument(argument.get()))
    {
...
@@ -10,7 +10,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
...
@@ -35,7 +35,7 @@ template <typename ADataType,
          index_t DScalarPerVector,
          index_t EScalarPerVector,
          index_t FScalarPerVector>
struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
{
    static constexpr auto I0 = Number<0>{};
...
@@ -268,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
        return true;
    };

    static auto MakeArgument(std::array<const void*, 5> p_inputs,
                             std::array<void*, 1> p_outputs,
                             std::vector<index_t> lengths,
                             std::vector<index_t> a_strides,
                             std::vector<index_t> b_strides,
...
@@ -283,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
                             std::vector<index_t> f_strides,
                             ElementwiseFunctor functor)
    {
        return Argument{static_cast<const ADataType*>(p_inputs[0]),
                        static_cast<const BDataType*>(p_inputs[1]),
                        static_cast<const CDataType*>(p_inputs[2]),
                        static_cast<const DDataType*>(p_inputs[3]),
                        static_cast<const EDataType*>(p_inputs[4]),
                        static_cast<FDataType*>(p_outputs[0]),
                        lengths,
                        a_strides,
                        b_strides,
...
@@ -299,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
                        functor};
    }
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, 5> p_inputs,
                        std::array<void*, 1> p_outputs,
                        std::vector<index_t> lengths,
                        std::vector<std::vector<index_t>> input_strides,
                        std::vector<std::vector<index_t>> output_strides,
                        ElementwiseFunctor functor) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
                                          static_cast<const BDataType*>(p_inputs[1]),
                                          static_cast<const CDataType*>(p_inputs[2]),
                                          static_cast<const DDataType*>(p_inputs[3]),
                                          static_cast<const EDataType*>(p_inputs[4]),
                                          static_cast<FDataType*>(p_outputs[0]),
                                          lengths,
                                          input_strides[0],
                                          input_strides[1],
                                          input_strides[2],
                                          input_strides[3],
                                          input_strides[4],
                                          output_strides[0],
                                          functor);
    }
    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "Device5AryElementwise"
            << "<"
            << "NDim = " << NDim
            << "MPerThread = " << MPerThread
            << "AScalarPerVector = " << AScalarPerVector
            << "BScalarPerVector = " << BScalarPerVector
            << "CScalarPerVector = " << CScalarPerVector
            << "DScalarPerVector = " << DScalarPerVector
            << "EScalarPerVector = " << EScalarPerVector
            << "FScalarPerVector = " << FScalarPerVector
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
...
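The practical effect of the array-based override above is that the layernorm client no longer passes five typed pointers; it packs its tensors once and each instance casts them back internally. A minimal sketch of that packing, under the assumption (from the normalize operation this commit adds) that the five inputs are x, mean, square-mean, gamma, and beta; the helper name is ours, not part of the commit:

#include <array>

// Packs the five normalize inputs in the order the 5-ary elementwise
// operation consumes them; input_strides[0..4] must pair with these slots.
std::array<const void*, 5> pack_normalize_inputs(const void* p_x,
                                                 const void* p_mean,
                                                 const void* p_square_mean,
                                                 const void* p_gamma,
                                                 const void* p_beta)
{
    return {p_x, p_mean, p_square_mean, p_gamma, p_beta};
}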
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <memory>

#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceBatchedGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t Batch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceBatchedGemmReducePtr =
std::unique_ptr<DeviceBatchedGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
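Note that, unlike the DeviceGemmReduce refactored further below, this batched interface still carries its element-wise functors as template parameters, so a client fixes them at compile time. A hedged sketch of the alias a caller would form; the functor choices mirror the mean/square-mean example earlier in this commit and are illustrative only:

#include "ck/utility/tuple.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Square      = ck::tensor_operation::element_wise::UnarySquare;
using Div         = ck::tensor_operation::element_wise::UnaryDivide;

// One (mean, square-mean) reduction pair per GEMM in the batch.
using BatchedGemmMeanSquareMeanPtr =
    ck::tensor_operation::device::DeviceBatchedGemmReducePtr<
        PassThrough,                    // functor for a
        PassThrough,                    // functor for b
        PassThrough,                    // functor for c
        ck::Tuple<PassThrough, Square>, // functors for reduction inputs
        ck::Tuple<Div, Div>>;           // functors for reduction outputs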
...
@@ -9,6 +9,7 @@
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"

namespace ck {
...
@@ -25,7 +26,7 @@ template <typename ADataType,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector>
struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{
    static constexpr auto I0 = Number<0>{};
...
@@ -198,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
        return true;
    };

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, 2> p_inputs,
                        std::array<void*, 1> p_outputs,
                        std::vector<index_t> lengths,
                        std::vector<std::vector<index_t>> input_strides,
                        std::vector<std::vector<index_t>> output_strides,
                        ElementwiseFunctor functor) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
                                          static_cast<const BDataType*>(p_inputs[1]),
                                          static_cast<CDataType*>(p_outputs[0]),
                                          lengths,
                                          input_strides[0],
                                          input_strides[1],
                                          output_strides[0],
                                          functor);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
...
@@ -226,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
        // clang-format off
        str << "DeviceBinaryElementwise"
            << "<"
            << "NDim = " << NDim
            << "MPerThread = " << MPerThread
            << "AScalarPerVector = " << AScalarPerVector
            << "BScalarPerVector = " << BScalarPerVector
            << "CScalarPerVector = " << CScalarPerVector
            << ">";
        // clang-format on
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include <memory>
#include <vector>

#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
std::array<void*, NumOutputTensor> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
using DeviceElementwisePtr =
std::unique_ptr<DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
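This thin base class is what makes the new client examples possible: code can enumerate elementwise instances without naming the concrete Device*Elementwise type. A minimal sketch of that dispatch for a 2-input/1-output, 2D problem; the instance vector is assumed to come from a tensor_operation_instance factory, which is outside this file:

#include <array>
#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using Add   = ck::tensor_operation::element_wise::Add;
using OpPtr = ck::tensor_operation::device::DeviceElementwisePtr<2, 1, 2, Add>;

// Prints the type string of every instance that supports the given problem.
void report_supported(std::vector<OpPtr>& ops,
                      std::array<const void*, 2> input,
                      std::array<void*, 1> output,
                      std::vector<ck::index_t> lengths,
                      std::vector<std::vector<ck::index_t>> input_strides,
                      std::vector<std::vector<ck::index_t>> output_strides)
{
    for(auto& op : ops)
    {
        auto arg = op->MakeArgumentPointer(
            input, output, lengths, input_strides, output_strides, Add{});
        if(op->IsSupportedArgument(arg.get()))
            std::cout << op->GetTypeString() << std::endl;
    }
}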
...
@@ -9,91 +9,34 @@
namespace ck {
namespace tensor_operation {
namespace device {

template <ck::index_t NumDTensor, ck::index_t NumReduce>
struct DeviceGemmReduce : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        const void* p_bias,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_c,
                        std::array<void*, NumReduce> p_reduces,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        std::array<ck::index_t, NumDTensor> StrideDs,
                        std::array<void*, 3> gemm_element_ops,
                        std::array<void*, NumDTensor> d_element_ops,
                        std::array<void*, NumReduce> reduce_in_element_ops,
                        std::array<void*, NumReduce> reduce_out_element_ops,
                        ck::index_t BatchCount = 1) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <ck::index_t NumDTensor, ck::index_t NumReduce>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
...
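With the functor templates gone from DeviceGemmReduce, element-wise operations travel through the interface as type-erased pointers that each instance casts back to the functor types it was compiled with. A short sketch of how a caller could fill those arrays for the NumDTensor = 1, NumReduce = 2 case (one add tensor, mean and square-mean reductions); the struct and its member names are ours, not part of the commit:

#include <array>

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Square      = ck::tensor_operation::element_wise::UnarySquare;
using Div         = ck::tensor_operation::element_wise::UnaryDivide;

// The functors must stay alive until the kernel launches, so a caller
// typically keeps them around the MakeArgumentPointer call.
struct GemmAddMeanSquareMeanOps
{
    PassThrough passOp{};
    Square squareOp{};
    Div divOp;

    explicit GemmAddMeanSquareMeanOps(int n) : divOp{n} {}

    std::array<void*, 3> gemm_element_ops() { return {&passOp, &passOp, &passOp}; }
    std::array<void*, 1> d_element_ops() { return {&passOp}; }
    std::array<void*, 2> reduce_in_element_ops() { return {&passOp, &squareOp}; }
    std::array<void*, 2> reduce_out_element_ops() { return {&divOp, &divOp}; }
};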