Unverified Commit 12235112 authored by rocking5566, committed by GitHub

external api for gemm + layernorm (#285)

* Extract base class for elementwise

* Refactor interface of DeviceGemmReduce. Do not use tuple in interface

* [What] Rename d to reduce in gemm + reduction related code
  [Why] Prepare for adding a d term for the add operation

* Unify base class of gemm + reduce and gemm + bias + add + reduce

* 1. Rename gemm_bias_add_reduce for external api
 2. Refine cmake

* Add normalize device operation

* [What] Reorder the arguments
  [Why] d0 is also an input used to compute c

* Add type string

* Add example of gemm_bias_add_layernorm via the external API

* Refactor example code

* clang-format

* Fix compile error

* clang-format

* Add external api for gemm_add_add_layernorm and normalize

* Add client example

* clang-format
parent aebd211c
add_executable(gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp)
target_link_libraries(gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <array>
#include <iomanip>
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp"
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using BiasDataType = F32;
using CDataType = F16;
using D0DataType = F16;
using ReduceDataType = F32;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
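// Minimal RAII wrapper for a raw device allocation: hipMalloc on construction and
// hipFree on destruction. Error codes are intentionally ignored in this example.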
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
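// Run one gemm + bias + add + mean/square-mean reduction instance from the external API:
// c = a * b + bias + d0, mean = reduce_sum(c) / N, square_mean = reduce_sum(c * c) / N.
// Returns false if the instance does not support this problem.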
template <typename gemm_reduce_op_ptr>
bool RunDeviceGemmMeanSquareMean(gemm_reduce_op_ptr& p_op,
const void* p_a,
const void* p_b,
const void* p_bias,
const void* p_d0,
void* p_c,
void* p_mean,
void* p_square_mean,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int StrideD0,
bool time_kernel)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
auto passOp = PassThrough{};
auto squareOp = UnarySquareElementOp{};
auto divOp = UnaryDivElementOp{N};
auto argument_ptr =
p_op->MakeArgumentPointer(p_a,
p_b,
p_bias,
{p_d0},
p_c,
{p_mean, p_square_mean},
M,
N,
K,
StrideA,
StrideB,
StrideC,
{StrideD0},
{&passOp, &passOp, &passOp}, // functor for a, b, c
{&passOp}, // functor for d0
{&passOp, &squareOp}, // functor for inputs of reduction
{&divOp, &divOp}); // functor for outputs of reduction
if(p_op->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = p_op->MakeInvokerPointer();
// If we measure the running time of gemm_reduce, the reduction output may be wrong,
// because the reduction tensor must be re-initialized before each kernel launch.
// With time_kernel = true the kernel is run many times without re-initializing the
// reduction output.
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
std::cout << "Gemm + reduce Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
return true;
}
return false;
}
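// Run one 5-input elementwise normalize instance from the external API. The zero strides
// below broadcast mean/square_mean along N and gamma/beta along M, so the functor computes
// roughly y = gamma * (x - mean) / sqrt(square_mean - mean * mean + eps) + beta per element.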
template <typename normalize_op_ptr>
bool RunDeviceNormalize2D(normalize_op_ptr& p_op,
const void* p_x,
const void* p_mean,
const void* p_square_mean,
const void* p_gamma,
const void* p_beta,
void* p_y,
int M,
int N,
int StrideX,
bool time_kernel)
{
std::array<const void*, 5> input = {p_x, p_mean, p_square_mean, p_gamma, p_beta};
std::array<void*, 1> output = {p_y};
auto normalize_functor = ck::tensor_operation::element_wise::Normalize{};
auto argument_ptr = p_op->MakeArgumentPointer(input,
output,
{M, N},
{{StrideX, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
{{StrideX, 1}},
normalize_functor);
if(p_op->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = p_op->MakeInvokerPointer();
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
std::cout << "Normalize Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
return true;
}
return false;
}
int main()
{
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024;
ck::index_t StrideD0 = 1024;
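// Query the pre-built instance lists exposed by the external API; each entry is a device
// operation that may or may not support this particular problem.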
const auto gemm_reduce_ptrs = ck::tensor_operation::device::device_gemm_instance::
get_device_gemm_add_add_mean_squaremean_instances<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout>();
const auto normalize_ptrs =
ck::tensor_operation::device::get_device_normalize_from_mean_meansquare_instances<
CDataType,
ReduceDataType,
ReduceDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>();
std::cout << "found " << gemm_reduce_ptrs.size()
<< " gemm_reduceMean_reduceSquareMean instances" << std::endl;
std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
SimpleDeviceMem bias_device_buf(sizeof(BiasDataType) * N);
SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
f_matrix_space_size(M, N, StrideD0, CLayout{}));
SimpleDeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * M);
SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * M);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
SimpleDeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * M * N);
bool b_time_kernel = true;
bool b_only_run_first_kernel = true;
// layernorm => (1) + (2)
// (1). c = gemm(a, b), reduce_mean(c), reduce_square_mean(c)
// (2). normalize(c, mean, square_mean, gamma, beta)
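// The c, mean and square_mean buffers produced by step (1) feed directly into step (2);
// only the final layer-norm output needs a separate allocation.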
for(auto& gemm_reduce_ptr : gemm_reduce_ptrs)
{
// run first available kernel
if(RunDeviceGemmMeanSquareMean(gemm_reduce_ptr,
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(),
d0_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideD0,
b_time_kernel))
{
if(b_only_run_first_kernel)
break;
}
else
{
std::cout << gemm_reduce_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
for(auto& normalize_ptr : normalize_ptrs)
{
if(RunDeviceNormalize2D(normalize_ptr,
c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
layerNorm_device_buf.GetDeviceBuffer(),
M,
N,
StrideC,
b_time_kernel))
{
if(b_only_run_first_kernel)
break;
}
else
{
std::cout << normalize_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
...@@ -7,3 +7,4 @@ find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}")
add_subdirectory(02_gemm_add_add_fastgelu)
add_subdirectory(03_gemm_layernorm)
...@@ -33,19 +33,19 @@ using BDataType = F16; ...@@ -33,19 +33,19 @@ using BDataType = F16;
using CDataType = F16; using CDataType = F16;
using GemmAccDataType = F32; using GemmAccDataType = F32;
using ReduceAccDataType = F32; using ReduceAccDataType = F32;
using DDataType = F64; using ReduceDataType = F64;
using DPtrsGlobal = ck::Tuple<DDataType*>; using ReducePtrsGlobal = ck::Tuple<ReduceDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using DsReduceOp = ck::Tuple<ck::reduce::Max>; using ReduceOps = ck::Tuple<ck::reduce::Max>;
using DsElementOp = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>; using ReduceElementOps = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
using DGlobalMemOp = using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>; ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
static constexpr auto GemmSpecialization = static constexpr auto GemmSpecialization =
...@@ -53,11 +53,11 @@ static constexpr auto GemmSpecialization = ...@@ -53,11 +53,11 @@ static constexpr auto GemmSpecialization =
// clang-format off // clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceElementOps, ReduceElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
...@@ -68,12 +68,12 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp ...@@ -68,12 +68,12 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
BElementOp, BElementOp,
CElementOp>; CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType> template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K) void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{ {
std::size_t gemm_flop = std::size_t(2) * M * N * K; std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(DDataType) * M; sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time; float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
...@@ -148,17 +148,17 @@ int main(int argc, char* argv[]) ...@@ -148,17 +148,17 @@ int main(int argc, char* argv[])
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d_m_host_result( Tensor<ReduceDataType> reduce_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d_m_device_result( Tensor<ReduceDataType> reduce_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d_m: " << d_m_host_result.mDesc << std::endl; std::cout << "reduce_m: " << reduce_m_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -176,35 +176,40 @@ int main(int argc, char* argv[]) ...@@ -176,35 +176,40 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace()); DeviceMem reduce_device_buf(sizeof(ReduceDataType) *
reduce_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
auto ds_element_op = DsElementOp{}; auto reduce_element_op = ReduceElementOps{}[ck::Number<0>{}];
auto p_ds_global = ck::make_tuple(static_cast<DDataType*>(d_device_buf.GetDeviceBuffer())); std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
std::array<void*, 1> reduce_element_ops = {&reduce_element_op};
std::array<void*, 1> p_reduces = {reduce_device_buf.GetDeviceBuffer()};
// do GEMM // do GEMM
auto gemm = DeviceGemmReduceInstance{}; auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), b_device_buf.GetDeviceBuffer(),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), nullptr,
p_ds_global, {},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
a_element_op, {},
b_element_op, gemm_element_ops,
c_element_op, {},
ds_element_op, reduce_element_ops,
ds_element_op); reduce_element_ops);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -215,7 +220,7 @@ int main(int argc, char* argv[]) ...@@ -215,7 +220,7 @@ int main(int argc, char* argv[])
// [CAUTION]: launch_and_time_kernel will not initialize D.
// If we run the kernel multiple times without re-initializing D, verification will fail.
d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest()); reduce_device_buf.SetValue(ck::NumericLimits<ReduceDataType>::Lowest());
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true; bool pass = true;
...@@ -223,7 +228,7 @@ int main(int argc, char* argv[]) ...@@ -223,7 +228,7 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
c_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d_device_buf.FromDevice(d_m_device_result.mData.data()); reduce_device_buf.FromDevice(reduce_m_device_result.mData.data());
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
...@@ -233,27 +238,27 @@ int main(int argc, char* argv[]) ...@@ -233,27 +238,27 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}]; auto reduce_op = ReduceOps{}[ck::Number<0>{}];
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue<ReduceAccDataType>(); ReduceAccDataType reduce_acc = reduce_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
ReduceAccDataType curr_val = ReduceAccDataType curr_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n)); ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
d_reduce_op(d_acc, curr_val); reduce_op(reduce_acc, curr_val);
}; };
d_m_host_result(m) = d_acc; reduce_m_host_result(m) = reduce_acc;
} }
pass = ck::utils::check_err(c_m_n_device_result.mData, pass = ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData, c_m_n_host_result.mData,
"Error: Incorrect results c") && "Error: Incorrect results c") &&
ck::utils::check_err(d_m_device_result.mData, ck::utils::check_err(reduce_m_device_result.mData,
d_m_host_result.mData, reduce_m_host_result.mData,
"Error: Incorrect results d", "Error: Incorrect results d",
1e-3, 1e-3,
1e-3); 1e-3);
...@@ -263,7 +268,7 @@ int main(int argc, char* argv[]) ...@@ -263,7 +268,7 @@ int main(int argc, char* argv[])
{ {
float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>( DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(
gemm_reduceMax_ave_time, M, N, K); gemm_reduceMax_ave_time, M, N, K);
} }
......
...@@ -33,27 +33,27 @@ using BDataType = F16; ...@@ -33,27 +33,27 @@ using BDataType = F16;
using CDataType = F16; using CDataType = F16;
using GemmAccDataType = F32; using GemmAccDataType = F32;
using ReduceAccDataType = F32; using ReduceAccDataType = F32;
using DDataType = F32; using ReduceDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>; using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add; using ReduceOp0 = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add; using ReduceOp1 = ck::reduce::Add;
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>; using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>; using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>; using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DGlobalMemOp = using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>; ck::InMemoryDataOperationEnum::AtomicAdd>;
...@@ -62,11 +62,11 @@ static constexpr auto GemmSpecialization = ...@@ -62,11 +62,11 @@ static constexpr auto GemmSpecialization =
// clang-format off // clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
...@@ -77,13 +77,13 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp ...@@ -77,13 +77,13 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
BElementOp, BElementOp,
CElementOp>; CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType> template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K) void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{ {
std::size_t gemm_flop = std::size_t(2) * M * N * K; std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(DDataType) * M + sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(DDataType) * M; sizeof(ReduceDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time; float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
...@@ -158,22 +158,22 @@ int main(int argc, char* argv[]) ...@@ -158,22 +158,22 @@ int main(int argc, char* argv[])
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_host_result( Tensor<ReduceDataType> reduce0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_host_result( Tensor<ReduceDataType> reduce1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_device_result( Tensor<ReduceDataType> reduce0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_device_result( Tensor<ReduceDataType> reduce1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl;
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -191,39 +191,48 @@ int main(int argc, char* argv[]) ...@@ -191,39 +191,48 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); reduce0_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()), std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
auto dxs_in_element_op = DxsInElementOps{}; auto passthrough = UnaryIdenticElementOp{};
auto dxs_out_element_op = DxsOutElementOps{N, N}; auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
// do GEMM // do GEMM
auto gemm = DeviceGemmReduceInstance{}; auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), b_device_buf.GetDeviceBuffer(),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), nullptr,
dxs_global, {},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
a_element_op, {},
b_element_op, gemm_element_ops,
c_element_op, {},
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op); reduce_out_element_ops);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -232,9 +241,9 @@ int main(int argc, char* argv[]) ...@@ -232,9 +241,9 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
// init D0, D1 to 0 // init reduction buffers to 0
d0_device_buf.SetZero(); reduce0_device_buf.SetZero();
d1_device_buf.SetZero(); reduce1_device_buf.SetZero();
// if time_kernel == true, the kernel will run multiple times. This kernel uses atomic-add,
// so the result will not be correct; set time_kernel = false for the correctness test
...@@ -244,8 +253,8 @@ int main(int argc, char* argv[]) ...@@ -244,8 +253,8 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
c_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_m_device_result.mData.data()); reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data()); reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
...@@ -255,42 +264,40 @@ int main(int argc, char* argv[]) ...@@ -255,42 +264,40 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
auto d0_reduce_op = D0ReduceOp{}; auto reduce0_op = ReduceOp0{};
auto d1_reduce_op = D1ReduceOp{}; auto reduce1_op = ReduceOp1{};
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>(); auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>(); auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n)); auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType d0_val; ReduceAccDataType square_c_val;
ReduceAccDataType d1_val; square(square_c_val, c_val);
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); reduce0_op(reduce0_acc, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); reduce1_op(reduce1_acc, square_c_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
} }
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc); div(reduce0_acc, reduce0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc); div(reduce1_acc, reduce1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc); reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc); reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
} }
pass = ck::utils::check_err(c_m_n_device_result.mData, pass = ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData, c_m_n_host_result.mData,
"Error: Incorrect results c") && "Error: Incorrect results c") &&
ck::utils::check_err(d0_m_device_result.mData, ck::utils::check_err(reduce0_m_device_result.mData,
d0_m_host_result.mData, reduce0_m_host_result.mData,
"Error: Incorrect results d0", "Error: Incorrect results d0",
1e-4, 1e-4,
1e-5) && 1e-5) &&
ck::utils::check_err(d1_m_device_result.mData, ck::utils::check_err(reduce1_m_device_result.mData,
d1_m_host_result.mData, reduce1_m_host_result.mData,
"Error: Incorrect results d1", "Error: Incorrect results d1",
1e-3, 1e-3,
1e-5); 1e-5);
...@@ -300,7 +307,7 @@ int main(int argc, char* argv[]) ...@@ -300,7 +307,7 @@ int main(int argc, char* argv[])
{ {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(ave_time, M, N, K); DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(ave_time, M, N, K);
} }
return pass ? 0 : 1; return pass ? 0 : 1;
......
...@@ -31,26 +31,26 @@ using ADataType = F16; ...@@ -31,26 +31,26 @@ using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using CDataType = F16; using CDataType = F16;
using ReduceAccDataType = F32; using ReduceAccDataType = F32;
using DDataType = F32; using ReduceDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>; using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add; using ReduceOp0 = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add; using ReduceOp1 = ck::reduce::Add;
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>; using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>; using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>; using ReduceOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using DGlobalMemOp = using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>; ck::InMemoryDataOperationEnum::AtomicAdd>;
...@@ -63,7 +63,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc ...@@ -63,7 +63,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
...@@ -143,16 +143,16 @@ int main(int argc, char* argv[]) ...@@ -143,16 +143,16 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_g_m_n_host_result( Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>( Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)}))); {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>( Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)}))); {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<CDataType> c_g_m_n_device_result( Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>( Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)}))); {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>( Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)}))); {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
...@@ -177,38 +177,48 @@ int main(int argc, char* argv[]) ...@@ -177,38 +177,48 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace()); DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace()); d0_g_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
d1_g_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_g_m_k.mData.data()); a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()), std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&passthrough, &passthrough};
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
// do GEMM // do GEMM
auto batched_gemm = DeviceBatchedGemmReduceInstance{}; auto batched_gemm = DeviceBatchedGemmReduceInstance{};
auto invoker = batched_gemm.MakeInvoker(); auto invoker = batched_gemm.MakeInvoker();
auto argument = auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
batched_gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), b_device_buf.GetDeviceBuffer(),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), nullptr,
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), {},
dxs_global, c_device_buf.GetDeviceBuffer(),
M, p_reduces,
N, M,
K, N,
StrideA, K,
StrideB, StrideA,
StrideC, StrideB,
a_element_op, StrideC,
b_element_op, {},
c_element_op, gemm_element_ops,
DxsInElementOps{}, {},
DxsOutElementOps{}, reduce_in_element_ops,
BatchCount); reduce_out_element_ops,
BatchCount);
if(!batched_gemm.IsSupportedArgument(argument)) if(!batched_gemm.IsSupportedArgument(argument))
{ {
...@@ -218,8 +228,8 @@ int main(int argc, char* argv[]) ...@@ -218,8 +228,8 @@ int main(int argc, char* argv[])
} }
// init D0, D1 to 0
d0_device_buf.SetZero(); reduce0_device_buf.SetZero();
d1_device_buf.SetZero(); reduce1_device_buf.SetZero();
// if time_kernel == true, the kernel will run multiple times. This kernel uses atomic-add,
// so the result will not be correct; set time_kernel = false for the correctness test
...@@ -241,8 +251,8 @@ int main(int argc, char* argv[]) ...@@ -241,8 +251,8 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker(); auto ref_invoker = ref_batched_gemm.MakeInvoker();
...@@ -252,15 +262,15 @@ int main(int argc, char* argv[]) ...@@ -252,15 +262,15 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
auto d0_reduce_op = D0ReduceOp{}; auto reduce0_op = ReduceOp0{};
auto d1_reduce_op = D1ReduceOp{}; auto reduce1_op = ReduceOp1{};
for(int batch = 0; batch < BatchCount; ++batch) for(int batch = 0; batch < BatchCount; ++batch)
{ {
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>(); auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>(); auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
...@@ -271,12 +281,12 @@ int main(int argc, char* argv[]) ...@@ -271,12 +281,12 @@ int main(int argc, char* argv[])
UnaryIdenticElementOp{}(d0_val, c_val); UnaryIdenticElementOp{}(d0_val, c_val);
UnarySquareElementOp{}(d1_val, c_val); UnarySquareElementOp{}(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val); reduce0_op(reduce0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val); reduce1_op(reduce1_acc, d1_val);
} }
d0_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d0_acc); d0_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce0_acc);
d1_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d1_acc); d1_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce1_acc);
} }
} }
......
...@@ -99,15 +99,17 @@ int main() ...@@ -99,15 +99,17 @@ int main()
a_m_n_device_buf.ToDevice(a_m_n.mData.data()); a_m_n_device_buf.ToDevice(a_m_n.mData.data());
b_n_device_buf.ToDevice(b_n.mData.data()); b_n_device_buf.ToDevice(b_n.mData.data());
std::array<const void*, 2> input = {a_m_n_device_buf.GetDeviceBuffer(),
b_n_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_n_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {Stride, 1};
std::vector<ck::index_t> b_strides = {0, 1};
std::vector<ck::index_t> c_strides = {Stride, 1};
auto broadcastAdd = DeviceElementwiseAddInstance{}; auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(), auto argument = broadcastAdd.MakeArgumentPointer(
b_n_device_buf.GetDeviceBuffer(), input, output, {M, N}, {a_strides, b_strides}, {c_strides}, Add{});
c_m_n_device_buf.GetDeviceBuffer(),
{M, N},
{Stride, 1},
{0, 1}, // broadcast in first dimension
{Stride, 1},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
......
...@@ -81,18 +81,24 @@ int main() ...@@ -81,18 +81,24 @@ int main()
a_m_device_buf.ToDevice(a_m.mData.data()); a_m_device_buf.ToDevice(a_m.mData.data());
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data()); b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
b_m_n_k_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_n_k_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {1, 0, 0};
std::vector<ck::index_t> b_strides{b_m_n_k.mDesc.GetStrides().begin(),
b_m_n_k.mDesc.GetStrides().end()};
std::vector<ck::index_t> c_strides{c_m_n_k.mDesc.GetStrides().begin(),
c_m_n_k.mDesc.GetStrides().end()};
auto broadcastAdd = DeviceElementwiseAddInstance{}; auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer( auto argument =
a_m_device_buf.GetDeviceBuffer(), broadcastAdd.MakeArgumentPointer(input,
b_m_n_k_device_buf.GetDeviceBuffer(), output,
c_m_n_k_device_buf.GetDeviceBuffer(), std::vector<ck::index_t>{mnk.begin(), mnk.end()},
std::vector<ck::index_t>{mnk.begin(), mnk.end()}, {a_strides, b_strides},
{1, 0, 0}, // broadcast A on second and third dimension {c_strides},
std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(), Add{});
b_m_n_k.mDesc.GetStrides().end()},
std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
c_m_n_k.mDesc.GetStrides().end()},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
......
@@ -79,15 +79,17 @@ int main()
    a_m_device_buf.ToDevice(a_m.mData.data());
    b_m_device_buf.ToDevice(b_m.mData.data());

+    std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
+                                        b_m_device_buf.GetDeviceBuffer()};
+    std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
+
+    std::vector<ck::index_t> a_strides = {1};
+    std::vector<ck::index_t> b_strides = {1};
+    std::vector<ck::index_t> c_strides = {1};
+
    auto broadcastAdd = DeviceElementwiseAddInstance{};
-    auto argument     = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
-                                                         b_m_device_buf.GetDeviceBuffer(),
-                                                         c_m_device_buf.GetDeviceBuffer(),
-                                                         {M},
-                                                         {1},
-                                                         {1},
-                                                         {1},
-                                                         Add{});
+    auto argument     = broadcastAdd.MakeArgumentPointer(
+        input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});

    if(!broadcastAdd.IsSupportedArgument(argument.get()))
    {
......
@@ -81,16 +81,22 @@ int main()
    a_device_buf.ToDevice(a.mData.data());
    b_device_buf.ToDevice(b.mData.data());

+    std::array<const void*, 2> input = {a_device_buf.GetDeviceBuffer(),
+                                        b_device_buf.GetDeviceBuffer()};
+    std::array<void*, 1> output = {c_device_buf.GetDeviceBuffer()};
+
+    std::vector<ck::index_t> a_strides{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()};
+    std::vector<ck::index_t> b_strides{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()};
+    std::vector<ck::index_t> c_strides{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()};
+
    auto broadcastAdd = DeviceElementwiseAddInstance{};
-    auto argument     = broadcastAdd.MakeArgumentPointer(
-        a_device_buf.GetDeviceBuffer(),
-        b_device_buf.GetDeviceBuffer(),
-        c_device_buf.GetDeviceBuffer(),
-        std::vector<ck::index_t>{nchw.begin(), nchw.end()},
-        std::vector<ck::index_t>{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()},
-        std::vector<ck::index_t>{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()},
-        std::vector<ck::index_t>{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()},
-        Add{});
+    auto argument =
+        broadcastAdd.MakeArgumentPointer(input,
+                                         output,
+                                         std::vector<ck::index_t>{nchw.begin(), nchw.end()},
+                                         {{a_strides}, b_strides},
+                                         {c_strides},
+                                         Add{});

    if(!broadcastAdd.IsSupportedArgument(argument.get()))
    {
......
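All four elementwise-add client examples above move to the same grouped calling convention: input pointers packed into a std::array<const void*, NumInput>, output pointers into a std::array<void*, NumOutput>, one stride vector per tensor, and the per-tensor stride vectors grouped into one brace list for the inputs and one for the outputs. The following is a minimal illustrative sketch of the 2-D bias-broadcast case, assembled only from the post-change code shown above; DeviceElementwiseAddInstance, Add, M, N, Stride and the device buffers are assumed to be defined as in that example.

// Sketch only, not part of the diff: the grouped-argument call of the refactored elementwise add.
std::array<const void*, 2> input = {a_m_n_device_buf.GetDeviceBuffer(),  // A is M x N
                                    b_n_device_buf.GetDeviceBuffer()};   // B has length N
std::array<void*, 1> output = {c_m_n_device_buf.GetDeviceBuffer()};

std::vector<ck::index_t> a_strides = {Stride, 1};
std::vector<ck::index_t> b_strides = {0, 1}; // stride 0 along M broadcasts B over the rows
std::vector<ck::index_t> c_strides = {Stride, 1};

auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument     = broadcastAdd.MakeArgumentPointer(
    input, output, {M, N}, {a_strides, b_strides}, {c_strides}, Add{});

if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
    // handle an unsupported problem size / layout here
}
// Invoker creation and the Run(...) call then follow exactly as in the full client example.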
@@ -31,12 +31,12 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
-using C0DataType = F32;
-using C1DataType = F16;
+using BiasDataType = F32;
+using D0DataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
-using DDataType = F32;
-using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
+using ReduceDataType = F32;
+using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
@@ -50,17 +50,17 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = ck::tensor_operation::element_wise::Relu;
-using C1ElementOp = PassThrough;
+using D0ElementOp = PassThrough;

using ReduceSumOp = ck::reduce::Add;
-using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
+using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>;

using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
-using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
-using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
-using DxsGlobalMemOp =
+using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
+using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
+using ReduceGlobalMemOps =
    ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                          ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -69,11 +69,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
-//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
-//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
-//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, C1ElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
+        < Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -89,8 +89,8 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance =
    ck::tensor_operation::device::Device5AryElementwise<CDataType,
-                                                        DDataType,
-                                                        DDataType,
+                                                        ReduceDataType,
+                                                        ReduceDataType,
                                                         GammaDataType,
                                                         BetaDataType,
                                                         LayerNormOutDataType,
@@ -125,10 +125,10 @@ auto f_host_tensor_descriptor2d =
    };

template <typename CDataType,
-          typename DDataType,
+          typename ReduceDataType,
          typename AccDataType,
-          typename C0DataType,
-          typename C1DataType,
+          typename BiasDataType,
+          typename D0DataType,
          typename A_functor,
          typename B_functor,
          typename C_functor,
@@ -136,8 +136,8 @@ template <typename CDataType,
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
                         const Tensor<ADataType>& a_m_k,
                         const Tensor<ADataType>& b_k_n,
-                        const Tensor<C0DataType>& bias_n,
-                        const Tensor<C1DataType>& c1_m_n,
+                        const Tensor<BiasDataType>& bias_n,
+                        const Tensor<D0DataType>& c1_m_n,
                         const Tensor<GammaDataType>& gamma_n,
                         const Tensor<GammaDataType>& beta_n,
                         A_functor a_element_op,
@@ -150,8 +150,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
    int StrideC = N;

    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
-    Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
-    Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
    auto averageOpInst = UnaryDivElementOp{N};

    auto ref_gemm = ReferenceGemmInstance{};
@@ -196,8 +196,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
        averageOpInst(mean_acc, mean_acc);
        averageOpInst(square_mean_acc, square_mean_acc);
-        mean_m(m)       = ck::type_convert<DDataType>(mean_acc);
-        meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
+        mean_m(m)       = ck::type_convert<ReduceDataType>(mean_acc);
+        meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
    }

    // LayerNorm
@@ -213,7 +213,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
                                     static_cast<AccDataType>(meanSquare_m(m)),
                                     static_cast<AccDataType>(gamma_n(n)),
                                     static_cast<AccDataType>(beta_n(n)));
-            out_m_n(m, n) = static_cast<DDataType>(out_acc);
+            out_m_n(m, n) = static_cast<ReduceDataType>(out_acc);
        }
    }
}
@@ -221,9 +221,9 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
template <typename ADataType,
          typename BDataType,
          typename CDataType,
-          typename C0DataType,
-          typename C1DataType,
-          typename DDataType,
+          typename BiasDataType,
+          typename D0DataType,
+          typename ReduceDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename NormalizeDataType>
@@ -231,12 +231,12 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{
    std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
    std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
-                                sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N +
-                                sizeof(C1DataType) * M * N + sizeof(DDataType) * M +
-                                sizeof(DDataType) * M;
+                                sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N +
+                                sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M +
+                                sizeof(ReduceDataType) * M;

-    std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
-                                     sizeof(DDataType) * M + sizeof(GammaDataType) * N +
+    std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
+                                     sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
                                     sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;

    float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
@@ -260,15 +260,15 @@ int main()
    ck::index_t StrideA = 1024;
    ck::index_t StrideB = 1024;
    ck::index_t StrideC = 1024;
-    ck::index_t StrideC1 = 1024;
+    ck::index_t StrideD0 = 1024;

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
-    Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
-    Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
-    Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
-    Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1));
+    Tensor<D0DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
+    Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<LayerNormOutDataType> layerNorm_m_n(
@@ -276,18 +276,18 @@ int main()
    a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
    b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
-    bias_n.GenerateTensorValue(GeneratorTensor_3<C0DataType>{-1, 1});
-    c1_m_n.GenerateTensorValue(GeneratorTensor_3<C1DataType>{-5, 5});
+    bias_n.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1, 1});
+    c1_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-5, 5});
    gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
    beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
-    DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace());
-    DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace());
-    DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
-    DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
+    DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace());
+    DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpace());
+    DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
+    DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
                                          reduceMeanSquare_m.mDesc.GetElementSpace());
    DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
    DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
@@ -297,44 +297,45 @@ int main()
    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    bias_device_buf.ToDevice(bias_n.mData.data());
-    c1_device_buf.ToDevice(c1_m_n.mData.data());
+    d0_device_buf.ToDevice(c1_m_n.mData.data());
    gamma_device_buf.ToDevice(gamma_n.mData.data());
    beta_device_buf.ToDevice(beta_n.mData.data());

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};
-    auto c1_element_op = C1ElementOp{};
-    auto dxs_global =
-        ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
-                       static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
-    auto dxs_in_element_op  = DxsInElementOps{};
-    auto dxs_out_element_op = DxsOutElementOps{N, N};
+    auto d_element_op = D0ElementOp{};
+    std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
+
+    auto passthrough = UnaryIdenticElementOp{};
+    auto square      = UnarySquareElementOp{};
+    auto div         = UnaryDivElementOp{N};
+    std::array<void*, 2> reduce_in_element_ops  = {&passthrough, &square};
+    std::array<void*, 2> reduce_out_element_ops = {&div, &div};
+
+    std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
+                                      reduceMeanSquare_device_buf.GetDeviceBuffer()};

    // Prepare GEMM, reduce_mean, reduce_mean_square
    auto gemmReduce         = DeviceGemmBiasAddReduceInstance{};
    auto gemmReduce_invoker = gemmReduce.MakeInvoker();
-    auto gemmReduce_argument =
-        gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
-                                static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
-                                dxs_global,
-                                M,
-                                N,
-                                K,
-                                StrideA,
-                                StrideB,
-                                StrideC,
-                                StrideC1,
-                                a_element_op,
-                                b_element_op,
-                                c_element_op,
-                                c1_element_op,
-                                dxs_in_element_op,
-                                dxs_out_element_op);
+    auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                                                       b_device_buf.GetDeviceBuffer(),
+                                                       bias_device_buf.GetDeviceBuffer(),
+                                                       {d0_device_buf.GetDeviceBuffer()},
+                                                       c_device_buf.GetDeviceBuffer(),
+                                                       p_reduces,
+                                                       M,
+                                                       N,
+                                                       K,
+                                                       StrideA,
+                                                       StrideB,
+                                                       StrideC,
+                                                       {StrideD0},
+                                                       gemm_element_ops,
+                                                       {&d_element_op},
+                                                       reduce_in_element_ops,
+                                                       reduce_out_element_ops);

    if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
    {
@@ -347,23 +348,25 @@ int main()
    reduceMeanSquare_device_buf.SetZero();

    // Prepare LayerNorm
+    std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(),
+                                        reduceMean_device_buf.GetDeviceBuffer(),
+                                        reduceMeanSquare_device_buf.GetDeviceBuffer(),
+                                        gamma_device_buf.GetDeviceBuffer(),
+                                        beta_device_buf.GetDeviceBuffer()};
+    std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
+
    auto normalize         = DeviceNormalizeInstance{};
    auto normalize_invoker = normalize.MakeInvoker();
-    auto normalize_argument = normalize.MakeArgument(
-        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-        static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
-        static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
-        static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
-        static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
-        static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
-        {M, N},
-        {StrideC, 1},
-        {1, 0},
-        {1, 0},
-        {0, 1},
-        {0, 1},
-        {StrideC, 1},
-        NormalizeFunctor{});
+    auto normalize_argument = normalize.MakeArgument(input,
+                                                     output,
+                                                     {M, N},
+                                                     {StrideC, 1},
+                                                     {1, 0},
+                                                     {1, 0},
+                                                     {0, 1},
+                                                     {0, 1},
+                                                     {StrideC, 1},
+                                                     NormalizeFunctor{});

    if(!normalize.IsSupportedArgument(normalize_argument))
    {
@@ -381,19 +384,19 @@ int main()
        Tensor<LayerNormOutDataType> host_layerNorm_m_n(
            f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));

-        host_gemm_layernorm<CDataType, DDataType, ReduceAccDataType>(host_layerNorm_m_n,
+        host_gemm_layernorm<CDataType, ReduceDataType, ReduceAccDataType>(host_layerNorm_m_n,
                                                                          a_m_k,
                                                                          b_k_n,
                                                                          bias_n,
                                                                          c1_m_n,
                                                                          gamma_n,
                                                                          beta_n,
                                                                          a_element_op,
                                                                          b_element_op,
                                                                          c_element_op,
-                                                                         c1_element_op,
+                                                                         d_element_op,
                                                                          M,
                                                                          N);

        layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
        pass &= ck::utils::check_err(layerNorm_m_n.mData,
@@ -416,9 +419,9 @@ int main()
        DumpGemmLayerNormPerf<ADataType,
                              BDataType,
                              CDataType,
-                              C0DataType,
-                              C1DataType,
-                              DDataType,
+                              BiasDataType,
+                              D0DataType,
+                              ReduceDataType,
                              GammaDataType,
                              BetaDataType,
                              LayerNormOutDataType>(
......
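For orientation, the heart of the new external API used in the example above can be condensed as follows. This is only a restatement of the post-change example code, not an independent implementation; every name comes from that example. Element-wise functors and reduction outputs are now handed over as std::array of type-erased void* instead of typed arguments:

// Condensed sketch of the refactored gemm + bias + add + reduce call shown above.
std::array<void*, 3> gemm_element_ops       = {&a_element_op, &b_element_op, &c_element_op};
std::array<void*, 2> reduce_in_element_ops  = {&passthrough, &square}; // identity, x^2
std::array<void*, 2> reduce_out_element_ops = {&div, &div};            // divide both sums by N
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
                                  reduceMeanSquare_device_buf.GetDeviceBuffer()};

auto gemmReduce          = DeviceGemmBiasAddReduceInstance{};
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                                   b_device_buf.GetDeviceBuffer(),
                                                   bias_device_buf.GetDeviceBuffer(),
                                                   {d0_device_buf.GetDeviceBuffer()}, // D0 inputs
                                                   c_device_buf.GetDeviceBuffer(),
                                                   p_reduces,
                                                   M, N, K,
                                                   StrideA, StrideB, StrideC,
                                                   {StrideD0},
                                                   gemm_element_ops,
                                                   {&d_element_op},
                                                   reduce_in_element_ops,
                                                   reduce_out_element_ops);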
@@ -33,8 +33,8 @@ using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
-using DDataType = F32;
-using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
+using ReduceDataType = F32;
+using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
@@ -48,15 +48,15 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;

using ReduceSumOp = ck::reduce::Add;
-using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
+using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>;

using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
-using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
-using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
-using DxsGlobalMemOp =
+using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
+using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
+using ReduceGlobalMemOps =
    ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                          ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -65,11 +65,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
-//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
-//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
-//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
+        < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -85,8 +85,8 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance =
    ck::tensor_operation::device::Device5AryElementwise<CDataType,
-                                                        DDataType,
-                                                        DDataType,
+                                                        ReduceDataType,
+                                                        ReduceDataType,
                                                         GammaDataType,
                                                         BetaDataType,
                                                         LayerNormOutDataType,
@@ -121,7 +121,7 @@ auto f_host_tensor_descriptor2d =
    };

template <typename CDataType,
-          typename DDataType,
+          typename ReduceDataType,
          typename A_functor,
          typename B_functor,
          typename C_functor>
@@ -140,8 +140,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
    int StrideC = N;

    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
-    Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
-    Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
    auto averageOpInst = UnaryDivElementOp{N};

    auto ref_gemm = ReferenceGemmInstance{};
@@ -172,8 +172,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
        averageOpInst(mean_acc, mean_acc);
        averageOpInst(square_mean_acc, square_mean_acc);
-        mean_m(m)       = ck::type_convert<DDataType>(mean_acc);
-        meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
+        mean_m(m)       = ck::type_convert<ReduceDataType>(mean_acc);
+        meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
    }

    // LayerNorm
@@ -197,7 +197,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
template <typename ADataType,
          typename BDataType,
          typename CDataType,
-          typename DDataType,
+          typename ReduceDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename NormalizeDataType>
@@ -205,11 +205,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{
    std::size_t gemm_flop = std::size_t(2) * M * N * K;
    std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
-                                sizeof(CDataType) * M * N + sizeof(DDataType) * M +
-                                sizeof(DDataType) * M;
+                                sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
+                                sizeof(ReduceDataType) * M;

-    std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
-                                     sizeof(DDataType) * M + sizeof(GammaDataType) * N +
+    std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
+                                     sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
                                     sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;

    float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
@@ -237,8 +237,8 @@ int main()
    Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
-    Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
-    Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<LayerNormOutDataType> layerNorm_m_n(
@@ -252,8 +252,8 @@ int main()
    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
-    DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
-    DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
+    DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
+    DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
                                          reduceMeanSquare_m.mDesc.GetElementSpace());
    DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
    DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
@@ -265,35 +265,40 @@ int main()
    gamma_device_buf.ToDevice(gamma_n.mData.data());
    beta_device_buf.ToDevice(beta_n.mData.data());

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};
-    auto dxs_global =
-        ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
-                       static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
-    auto dxs_in_element_op  = DxsInElementOps{};
-    auto dxs_out_element_op = DxsOutElementOps{N, N};
+    std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
+
+    auto passthrough = UnaryIdenticElementOp{};
+    auto square      = UnarySquareElementOp{};
+    auto div         = UnaryDivElementOp{N};
+    std::array<void*, 2> reduce_in_element_ops  = {&passthrough, &square};
+    std::array<void*, 2> reduce_out_element_ops = {&div, &div};
+
+    std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
+                                      reduceMeanSquare_device_buf.GetDeviceBuffer()};

    // Prepare GEMM, reduce_mean, reduce_mean_square
    auto gemmReduce         = DeviceGemmReduceInstance{};
    auto gemmReduce_invoker = gemmReduce.MakeInvoker();
-    auto gemmReduce_argument =
-        gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                dxs_global,
-                                M,
-                                N,
-                                K,
-                                StrideA,
-                                StrideB,
-                                StrideC,
-                                a_element_op,
-                                b_element_op,
-                                c_element_op,
-                                dxs_in_element_op,
-                                dxs_out_element_op);
+    auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                                                       b_device_buf.GetDeviceBuffer(),
+                                                       nullptr,
+                                                       {},
+                                                       c_device_buf.GetDeviceBuffer(),
+                                                       p_reduces,
+                                                       M,
+                                                       N,
+                                                       K,
+                                                       StrideA,
+                                                       StrideB,
+                                                       StrideC,
+                                                       {},
+                                                       gemm_element_ops,
+                                                       {},
+                                                       reduce_in_element_ops,
+                                                       reduce_out_element_ops);

    if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
    {
@@ -306,23 +311,25 @@ int main()
    reduceMeanSquare_device_buf.SetZero();

    // Prepare LayerNorm
+    std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(),
+                                        reduceMean_device_buf.GetDeviceBuffer(),
+                                        reduceMeanSquare_device_buf.GetDeviceBuffer(),
+                                        gamma_device_buf.GetDeviceBuffer(),
+                                        beta_device_buf.GetDeviceBuffer()};
+    std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
+
    auto normalize         = DeviceNormalizeInstance{};
    auto normalize_invoker = normalize.MakeInvoker();
-    auto normalize_argument = normalize.MakeArgument(
-        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-        static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
-        static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
-        static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
-        static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
-        static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
-        {M, N},
-        {StrideC, 1},
-        {1, 0},
-        {1, 0},
-        {0, 1},
-        {0, 1},
-        {StrideC, 1},
-        NormalizeFunctor{});
+    auto normalize_argument = normalize.MakeArgument(input,
+                                                     output,
+                                                     {M, N},
+                                                     {StrideC, 1},
+                                                     {1, 0},
+                                                     {1, 0},
+                                                     {0, 1},
+                                                     {0, 1},
+                                                     {StrideC, 1},
+                                                     NormalizeFunctor{});

    if(!normalize.IsSupportedArgument(normalize_argument))
    {
@@ -340,16 +347,16 @@ int main()
        Tensor<LayerNormOutDataType> host_layerNorm_m_n(
            f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));

-        host_gemm_layernorm<CDataType, DDataType>(host_layerNorm_m_n,
+        host_gemm_layernorm<CDataType, ReduceDataType>(host_layerNorm_m_n,
                                                       a_m_k,
                                                       b_k_n,
                                                       gamma_n,
                                                       beta_n,
                                                       a_element_op,
                                                       b_element_op,
                                                       c_element_op,
                                                       M,
                                                       N);

        layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
        pass &= ck::utils::check_err(layerNorm_m_n.mData,
@@ -372,7 +379,7 @@ int main()
        DumpGemmLayerNormPerf<ADataType,
                              BDataType,
                              CDataType,
-                              DDataType,
+                              ReduceDataType,
                              GammaDataType,
                              BetaDataType,
                              LayerNormOutDataType>(
......
@@ -10,7 +10,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
@@ -35,7 +35,7 @@ template <typename ADataType,
          index_t DScalarPerVector,
          index_t EScalarPerVector,
          index_t FScalarPerVector>
-struct Device5AryElementwise : public BaseOperator
+struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
{
    static constexpr auto I0 = Number<0>{};
@@ -268,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
        return true;
    };

-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             const CDataType* p_c,
-                             const DDataType* p_d,
-                             const EDataType* p_e,
-                             FDataType* p_f,
+    static auto MakeArgument(std::array<const void*, 5> p_inputs,
+                             std::array<void*, 1> p_outputs,
                             std::vector<index_t> lengths,
                             std::vector<index_t> a_strides,
                             std::vector<index_t> b_strides,
@@ -283,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
                             std::vector<index_t> f_strides,
                             ElementwiseFunctor functor)
    {
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        p_d,
-                        p_e,
-                        p_f,
+        return Argument{static_cast<const ADataType*>(p_inputs[0]),
+                        static_cast<const BDataType*>(p_inputs[1]),
+                        static_cast<const CDataType*>(p_inputs[2]),
+                        static_cast<const DDataType*>(p_inputs[3]),
+                        static_cast<const EDataType*>(p_inputs[4]),
+                        static_cast<FDataType*>(p_outputs[0]),
                        lengths,
                        a_strides,
                        b_strides,
@@ -299,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
                        functor};
    }

-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      const void* p_c,
-                                                      const void* p_d,
-                                                      const void* p_e,
-                                                      void* p_f,
-                                                      std::vector<index_t> lengths,
-                                                      std::vector<index_t> a_strides,
-                                                      std::vector<index_t> b_strides,
-                                                      std::vector<index_t> c_strides,
-                                                      std::vector<index_t> d_strides,
-                                                      std::vector<index_t> e_strides,
-                                                      std::vector<index_t> f_strides,
-                                                      ElementwiseFunctor functor)
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::array<const void*, 5> p_inputs,
+                        std::array<void*, 1> p_outputs,
+                        std::vector<index_t> lengths,
+                        std::vector<std::vector<index_t>> input_strides,
+                        std::vector<std::vector<index_t>> output_strides,
+                        ElementwiseFunctor functor) override
    {
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
-                                          static_cast<const CDataType*>(p_c),
-                                          static_cast<const DDataType*>(p_d),
-                                          static_cast<const EDataType*>(p_e),
-                                          static_cast<FDataType*>(p_f),
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
+                                          static_cast<const BDataType*>(p_inputs[1]),
+                                          static_cast<const CDataType*>(p_inputs[2]),
+                                          static_cast<const DDataType*>(p_inputs[3]),
+                                          static_cast<const EDataType*>(p_inputs[4]),
+                                          static_cast<FDataType*>(p_outputs[0]),
                                          lengths,
-                                          a_strides,
-                                          b_strides,
-                                          c_strides,
-                                          d_strides,
-                                          e_strides,
-                                          f_strides,
+                                          input_strides[0],
+                                          input_strides[1],
+                                          input_strides[2],
+                                          input_strides[3],
+                                          input_strides[4],
+                                          output_strides[0],
                                          functor);
    }

    static auto MakeInvoker() { return Invoker{}; }

-    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+        // clang-format off
+        str << "Device5aryElementwise"
+            << "<"
+            << "NDim = " << NDim
+            << "MPerThread = " << MPerThread
+            << "AScalarPerVector = " << AScalarPerVector
+            << "BScalarPerVector = " << BScalarPerVector
+            << "CScalarPerVector = " << CScalarPerVector
+            << "DScalarPerVector = " << DScalarPerVector
+            << "EScalarPerVector = " << EScalarPerVector
+            << "FScalarPerVector = " << FScalarPerVector
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
};
} // namespace device
} // namespace tensor_operation
......
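With Device5AryElementwise now implementing the DeviceElementwise<5, 1, NDim, ElementwiseFunctor> interface, the polymorphic MakeArgumentPointer takes the five inputs and the single output as arrays, with the strides grouped per tensor. The following is a hedged sketch of how the layernorm step could be driven through that interface, reusing the buffer names, stride values and NormalizeFunctor from the examples earlier in this commit; only the grouping of the stride vectors is not shown verbatim above.

// Sketch only: a polymorphic call into the refactored Device5AryElementwise interface.
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(),               // x
                                    reduceMean_device_buf.GetDeviceBuffer(),       // E[x]
                                    reduceMeanSquare_device_buf.GetDeviceBuffer(), // E[x^2]
                                    gamma_device_buf.GetDeviceBuffer(),
                                    beta_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};

// mean / mean-square broadcast along N; gamma / beta broadcast along M
std::vector<std::vector<ck::index_t>> input_strides = {
    {StrideC, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}};
std::vector<std::vector<ck::index_t>> output_strides = {{StrideC, 1}};

auto normalize = DeviceNormalizeInstance{};
auto argument  = normalize.MakeArgumentPointer(
    input, output, {M, N}, input_strides, output_strides, NormalizeFunctor{});
auto invoker   = normalize.MakeInvokerPointer();
// Check IsSupportedArgument(argument.get()) and call invoker->Run(...) as in the examples.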
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceBatchedGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t Batch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceBatchedGemmReducePtr =
std::unique_ptr<DeviceBatchedGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
#include "ck/device_utility/device_prop.hpp"
@@ -23,16 +23,16 @@ namespace device {
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename DGridDescriptor_MBlock_MPerBlock,
+          typename ReduceGridDescriptor_MBlock_MPerBlock,
          typename ComputeBasePrtOfBatch,
          typename Block2CTileMap,
          bool HasMainK0BlockLoop>
@@ -44,18 +44,18 @@ __global__ void
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
-        DPtrsGlobal p_ds_grid,
+        ReducePtrsGlobal p_reduces_grid,
        const index_t batch_count,
        const AElementwiseOperation a_element_op,
        const BElementwiseOperation b_element_op,
        const CElementwiseOperation c_element_op,
-        const DxsInElementwiseOperation dxs_in_element_op,
-        const DxsReduceAccElementwiseOperation dxs_out_element_op,
+        const ReduceInElementwiseOperations reduce_in_element_ops,
+        const ReduceAccElementwiseOperations reduce_out_element_ops,
        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
+        const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
        const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
        const Block2CTileMap block_2_ctile_map)
{
@@ -71,10 +71,10 @@ __global__ void
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));

-    static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
+    static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
        const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
-        p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
+        p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset;
    });

    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -82,33 +82,33 @@ __global__ void
    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_c_grid + c_batch_offset,
-                                                  p_ds_grid,
+                                                  p_reduces_grid,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
-                                                  dxs_in_element_op,
-                                                  dxs_out_element_op,
+                                                  reduce_in_element_ops,
+                                                  reduce_out_element_ops,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  d_grid_desc_mblock_mperblock,
+                                                  reduce_grid_desc_mblock_mperblock,
                                                   block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
-    ignore = p_ds_grid;
+    ignore = p_reduces_grid;
    ignore = batch_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
-    ignore = dxs_in_element_op;
-    ignore = dxs_out_element_op;
+    ignore = reduce_in_element_ops;
+    ignore = reduce_out_element_ops;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = d_grid_desc_mblock_mperblock;
+    ignore = reduce_grid_desc_mblock_mperblock;
    ignore = compute_base_ptr_of_batch_;
    ignore = block_2_ctile_map;
#endif
@@ -126,14 +126,14 @@ template <typename ALayout,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename ReduceAccDataType,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
-          typename DxsReduceOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
-          typename DGlobalMemoryDataOperation,
+          typename ReduceOperations,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
+          typename ReduceGlobalMemoryDataOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
@@ -168,12 +168,7 @@ template <typename ALayout,
          index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceBatchedGemmReduce_Xdl_CShuffle
-    : public DeviceBatchedGemmReduce<AElementwiseOperation,
-                                     BElementwiseOperation,
-                                     CElementwiseOperation,
-                                     DxsInElementwiseOperation,
-                                     DxsReduceAccElementwiseOperation>
+struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
{
    using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
@@ -446,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
} }
// assume D is packed tensor // assume D is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw) static auto MakeReduceGridDescriptor_M(index_t MRaw)
{ {
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
@@ -474,7 +469,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
@@ -527,19 +522,19 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
ReduceAccDataType, ReduceAccDataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsReduceOperation, ReduceOperations,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsReduceAccElementwiseOperation, ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation, ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
DGridDesc_M, ReduceGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
@@ -582,7 +577,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
Argument(const ADataType* p_a_grid, Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
DPtrsGlobal p_ds_grid, ReducePtrsGlobal p_reduces_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
@@ -592,31 +587,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, ReduceInElementwiseOperations reduce_in_element_ops,
DxsReduceAccElementwiseOperation dxs_out_element_op, ReduceAccElementwiseOperations reduce_out_element_ops,
index_t Batch) index_t Batch)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_ds_grid_{p_ds_grid}, p_reduces_grid_{p_reduces_grid},
Batch_(Batch), Batch_(Batch),
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{}, reduce_grid_desc_mblock_mperblock_{},
compute_base_ptr_of_batch_{ compute_base_ptr_of_batch_{
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()), type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())}, type_convert<index_t>(reduce_grid_desc_m_.GetElementSpaceSize())},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
dxs_in_element_op_{dxs_in_element_op}, reduce_in_element_ops_{reduce_in_element_ops},
dxs_out_element_op_{dxs_out_element_op} reduce_out_element_ops_{reduce_out_element_ops}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
@@ -627,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ = reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
} }
} }
@@ -636,22 +631,23 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
DPtrsGlobal p_ds_grid_; ReducePtrsGlobal p_reduces_grid_;
index_t Batch_; index_t Batch_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_; ReduceGridDesc_M reduce_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_; ReduceInElementwiseOperations reduce_in_element_ops_;
DxsReduceAccElementwiseOperation dxs_out_element_op_; ReduceAccElementwiseOperations reduce_out_element_ops_;
}; };
// Invoker // Invoker
@@ -678,7 +674,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}" std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
<< std::endl; << std::endl;
} }
#endif #endif
@@ -704,16 +700,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsReduceAccElementwiseOperation, ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
@@ -727,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_ds_grid_, arg.p_reduces_grid_,
arg.Batch_, arg.Batch_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.dxs_in_element_op_, arg.reduce_in_element_ops_,
arg.dxs_out_element_op_, arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.reduce_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
@@ -747,16 +743,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsReduceAccElementwiseOperation, ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
@@ -770,17 +766,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_ds_grid_, arg.p_reduces_grid_,
arg.Batch_, arg.Batch_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.dxs_in_element_op_, arg.reduce_in_element_ops_,
arg.dxs_out_element_op_, arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.reduce_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
@@ -824,38 +820,76 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
} }
} }
static auto MakeArgument(const ADataType* p_a, static constexpr int NumReduce = ReduceOperations::Size();
const BDataType* p_b, static auto MakeArgument(const void* p_a,
CDataType* p_c, const void* p_b,
DPtrsGlobal p_dxs, const void* p_bias,
index_t MRaw, std::array<const void*, 0> p_ds,
index_t NRaw, void* p_c,
index_t KRaw, std::array<void*, NumReduce> p_reduces,
index_t StrideA, ck::index_t M,
index_t StrideB, ck::index_t N,
index_t StrideC, ck::index_t K,
AElementwiseOperation a_element_op, ck::index_t StrideA,
BElementwiseOperation b_element_op, ck::index_t StrideB,
CElementwiseOperation c_element_op, ck::index_t StrideC,
DxsInElementwiseOperation dxs_in_element_op, std::array<ck::index_t, 0> StrideDs,
DxsReduceAccElementwiseOperation dxs_out_element_op, std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t Batch) index_t Batch)
{ {
return Argument{p_a, (void)p_bias;
p_b, (void)p_ds;
p_c, (void)StrideDs;
p_dxs, (void)d_element_ops;
MRaw,
NRaw, ReducePtrsGlobal reduce_tuple = generate_tuple(
KRaw, [&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
reduce_tuple,
M,
N,
K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op, reduce_out_element_ops,
Batch}; Batch};
} }
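Both MakeArgument and MakeArgumentPointer now accept type-erased std::array<void*, N> packs and rebuild the statically typed tuples with generate_tuple, casting each slot back to the element type recorded in ReducePtrsGlobal (and likewise for the elementwise-operation tuples). A rough standalone analogue of that cast-back step, using std::tuple rather than CK's Tuple, with the tuple type and buffer names chosen purely for illustration:

#include <array>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

// Rebuild a typed tuple of pointers (e.g. std::tuple<float*, float*>) from an
// array of void*, the way generate_tuple + remove_pointer_t are used above.
template <typename PtrTuple, std::size_t... Is>
PtrTuple make_typed_ptrs(const std::array<void*, sizeof...(Is)>& raw, std::index_sequence<Is...>)
{
    return PtrTuple{static_cast<std::tuple_element_t<Is, PtrTuple>>(raw[Is])...};
}

template <typename PtrTuple>
PtrTuple make_typed_ptrs(const std::array<void*, std::tuple_size_v<PtrTuple>>& raw)
{
    return make_typed_ptrs<PtrTuple>(raw, std::make_index_sequence<std::tuple_size_v<PtrTuple>>{});
}

// Usage sketch (hypothetical buffers): two reduction outputs, both float* on the device.
// std::array<void*, 2> p_reduces{p_mean, p_square_mean};
// auto reduce_tuple = make_typed_ptrs<std::tuple<float*, float*>>(p_reduces);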
@@ -865,37 +899,73 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c, void* p_c,
void* p_dxs, std::array<void*, NumReduce> p_reduces,
index_t MRaw, ck::index_t M,
index_t NRaw, ck::index_t N,
index_t KRaw, ck::index_t K,
index_t StrideA, ck::index_t StrideA,
index_t StrideB, ck::index_t StrideB,
index_t StrideC, ck::index_t StrideC,
AElementwiseOperation a_element_op, std::array<ck::index_t, 0> StrideDs,
BElementwiseOperation b_element_op, std::array<void*, 3> gemm_element_ops,
CElementwiseOperation c_element_op, std::array<void*, 0> d_element_ops,
DxsInElementwiseOperation dxs_in_element_op, std::array<void*, NumReduce> reduce_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op, std::array<void*, NumReduce> reduce_out_element_op,
index_t Batch) override index_t Batch = 1) override
{ {
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs)); (void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
dxs_tuple, reduce_tuple,
MRaw, M,
NRaw, N,
KRaw, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op, reduce_out_element_ops,
Batch); Batch);
} }
...
@@ -9,6 +9,7 @@
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp" #include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace ck { namespace ck {
@@ -25,7 +26,7 @@ template <typename ADataType,
index_t AScalarPerVector, index_t AScalarPerVector,
index_t BScalarPerVector, index_t BScalarPerVector,
index_t CScalarPerVector> index_t CScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
@@ -198,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
return true; return true;
}; };
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, virtual std::unique_ptr<BaseArgument>
const void* p_b, MakeArgumentPointer(std::array<const void*, 2> p_inputs,
void* p_c, std::array<void*, 1> p_outputs,
std::vector<index_t> lengths, std::vector<index_t> lengths,
std::vector<index_t> a_strides, std::vector<std::vector<index_t>> input_strides,
std::vector<index_t> b_strides, std::vector<std::vector<index_t>> output_strides,
std::vector<index_t> c_strides, ElementwiseFunctor functor) override
ElementwiseFunctor functor)
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_inputs[1]),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_outputs[0]),
lengths, lengths,
a_strides, input_strides[0],
b_strides, input_strides[1],
c_strides, output_strides[0],
functor); functor);
} }
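From the caller's side, the reworked signature moves all broadcasting and layout information into one lengths vector plus one stride vector per tensor. A hedged sketch of how a client might drive a 2-D, two-input/one-output instance through this interface; op_ptr, the Add functor, the device buffers and the extents are placeholders rather than names taken from this change:

// Assumes the CK headers used elsewhere in this PR are included and that op_ptr is a
// DeviceElementwisePtr<2, 1, 2, Add>-style handle; p_a, p_b, p_c, M, N and the strides
// stand for the client's device buffers and problem sizes.
std::array<const void*, 2> inputs  = {p_a, p_b};
std::array<void*, 1>       outputs = {p_c};

std::vector<ck::index_t>              lengths        = {M, N};
std::vector<std::vector<ck::index_t>> input_strides  = {{StrideA, 1}, {StrideB, 1}}; // row-major A, B
std::vector<std::vector<ck::index_t>> output_strides = {{StrideC, 1}};               // row-major C

auto argument_ptr = op_ptr->MakeArgumentPointer(
    inputs, outputs, lengths, input_strides, output_strides, Add{});
auto invoker_ptr  = op_ptr->MakeInvokerPointer();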
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); } std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override std::string GetTypeString() const override
{ {
auto str = std::stringstream(); auto str = std::stringstream();
@@ -226,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off // clang-format off
str << "DeviceBinaryElementwise" str << "DeviceBinaryElementwise"
<< "<" << "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread << "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< ">"; << ">";
// clang-format on // clang-format on
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
std::array<void*, NumOutputTensor> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
using DeviceElementwisePtr =
std::unique_ptr<DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
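The new base class bakes the tensor counts into the type: NumInputTensor and NumOutputTensor are non-type template parameters, so DeviceBinaryElementwise above plugs in as the <2, 1, NDim, Functor> case and heterogeneous instances can be collected behind the matching DeviceElementwisePtr alias. A short sketch, where the functor and the NDim value are illustrative only:

// Illustrative only: SomeBinaryFunctor stands for whatever elementwise op the
// instances were built with (e.g. an add functor); NDim = 2 is an arbitrary choice.
using BinaryElementwisePtr =
    ck::tensor_operation::device::DeviceElementwisePtr<2, 1, 2, SomeBinaryFunctor>;

std::vector<BinaryElementwisePtr> op_ptrs; // typically filled by an instance factory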
@@ -29,20 +29,20 @@ template <typename ALayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename C0DataType, typename BiasDataType,
typename C1DataType, typename D0DataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename ReduceAccDataType, typename ReduceAccDataType,
typename DPtrsGlobal, typename ReducePtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation, typename D0ElementwiseOperation,
typename DxsReduceOperation, typename ReduceOperations,
typename DxsInElementwiseOperation, typename ReduceInElementwiseOperations,
typename DxsReduceAccElementwiseOperation, typename ReduceAccElementwiseOperations,
typename DGlobalMemoryDataOperation, typename ReduceGlobalMemoryDataOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
@@ -77,13 +77,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasAddReduce_Xdl_CShuffle struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()>
: public DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>
{ {
using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;
@@ -356,7 +350,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
} }
// assume D is packed tensor // assume D is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw) static auto MakeReduceGridDescriptor_M(index_t MRaw)
{ {
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
@@ -386,7 +380,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0)); using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
@@ -394,25 +388,25 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
C0DataType, BiasDataType,
C1DataType, D0DataType,
ReduceAccDataType, ReduceAccDataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation, D0ElementwiseOperation,
DxsReduceOperation, ReduceOperations,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsReduceAccElementwiseOperation, ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation, ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
C0GridDesc_M_N, C0GridDesc_M_N,
C1GridDesc_M_N, C1GridDesc_M_N,
DGridDesc_M, ReduceGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
@@ -455,9 +449,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
Argument(const ADataType* p_a_grid, Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
const C0DataType* p_c0_grid, const BiasDataType* p_bias_grid,
const C1DataType* p_c1_grid, const D0DataType* p_d0_grid,
DPtrsGlobal p_ds_grid, ReducePtrsGlobal p_reduces_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
@@ -468,32 +462,32 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op, D0ElementwiseOperation d0_element_op,
DxsInElementwiseOperation dxs_in_element_op, ReduceInElementwiseOperations reduce_in_element_ops,
DxsReduceAccElementwiseOperation dxs_out_element_op) ReduceAccElementwiseOperations reduce_out_element_ops)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_c0_grid_{p_c0_grid}, p_bias_grid_{p_bias_grid},
p_c1_grid_{p_c1_grid}, p_d0_grid_{p_d0_grid},
p_ds_grid_{p_ds_grid}, p_reduces_grid_{p_reduces_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)}, c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)},
c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)}, c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
c0_grid_desc_mblock_mperblock_nblock_nperblock_{}, c0_grid_desc_mblock_mperblock_nblock_nperblock_{},
c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{}, reduce_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
c1_element_op_{c1_element_op}, d0_element_op_{d0_element_op},
dxs_in_element_op_{dxs_in_element_op}, reduce_in_element_ops_{reduce_in_element_ops},
dxs_out_element_op_{dxs_out_element_op} reduce_out_element_ops_{reduce_out_element_ops}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
@@ -512,8 +506,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c1_grid_desc_m_n_); c1_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ = reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
} }
} }
@@ -521,29 +515,30 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
const C0DataType* p_c0_grid_; const BiasDataType* p_bias_grid_;
const C1DataType* p_c1_grid_; const D0DataType* p_d0_grid_;
DPtrsGlobal p_ds_grid_; ReducePtrsGlobal p_reduces_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_M_N c0_grid_desc_m_n_; C0GridDesc_M_N c0_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_; C1GridDesc_M_N c1_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_; ReduceGridDesc_M reduce_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock_; c0_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_; c1_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
C1ElementwiseOperation c1_element_op_; D0ElementwiseOperation d0_element_op_;
DxsInElementwiseOperation dxs_in_element_op_; ReduceInElementwiseOperations reduce_in_element_ops_;
DxsReduceAccElementwiseOperation dxs_out_element_op_; ReduceAccElementwiseOperations reduce_out_element_ops_;
}; };
// Invoker // Invoker
@@ -574,21 +569,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
C0DataType, BiasDataType,
C1DataType, D0DataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation, D0ElementwiseOperation,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsReduceAccElementwiseOperation, ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
@@ -601,21 +596,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_c0_grid_, arg.p_bias_grid_,
arg.p_c1_grid_, arg.p_d0_grid_,
arg.p_ds_grid_, arg.p_reduces_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.c1_element_op_, arg.d0_element_op_,
arg.dxs_in_element_op_, arg.reduce_in_element_ops_,
arg.dxs_out_element_op_, arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
else else
@@ -624,21 +619,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
C0DataType, BiasDataType,
C1DataType, D0DataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation, D0ElementwiseOperation,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsReduceAccElementwiseOperation, ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
@@ -651,21 +646,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_c0_grid_, arg.p_bias_grid_,
arg.p_c1_grid_, arg.p_d0_grid_,
arg.p_ds_grid_, arg.p_reduces_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.c1_element_op_, arg.d0_element_op_,
arg.dxs_in_element_op_, arg.reduce_in_element_ops_,
arg.dxs_out_element_op_, arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
@@ -700,45 +695,76 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const ADataType* p_a, static constexpr int NumReduce = ReduceOperations::Size();
const BDataType* p_b, static auto MakeArgument(const void* p_a,
CDataType* p_c, const void* p_b,
const C0DataType* p_c0, const void* p_bias,
const C1DataType* p_c1, std::array<const void*, 1> p_ds,
DPtrsGlobal p_dxs, void* p_c,
index_t MRaw, std::array<void*, NumReduce> p_reduces,
index_t NRaw, ck::index_t M,
index_t KRaw, ck::index_t N,
index_t StrideA, ck::index_t K,
index_t StrideB, ck::index_t StrideA,
index_t StrideC, ck::index_t StrideB,
index_t StrideC1, ck::index_t StrideC,
AElementwiseOperation a_element_op, std::array<ck::index_t, 1> StrideDs,
BElementwiseOperation b_element_op, std::array<void*, 3> gemm_element_ops,
CElementwiseOperation c_element_op, std::array<void*, 1> d_element_ops,
C1ElementwiseOperation c1_element_op, std::array<void*, NumReduce> reduce_in_element_op,
DxsInElementwiseOperation dxs_in_element_op, std::array<void*, NumReduce> reduce_out_element_op)
DxsReduceAccElementwiseOperation dxs_out_element_op)
{ {
return Argument{p_a, ReducePtrsGlobal reduce_tuple = generate_tuple(
p_b, [&](auto I) {
p_c, auto tmp = ReducePtrsGlobal{}[I];
p_c0, using T = remove_pointer_t<decltype(tmp)>;
p_c1, return static_cast<T*>(p_reduces[I]);
p_dxs, },
MRaw, Number<NumReduce>{});
NRaw,
KRaw, ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
D0ElementwiseOperation d_element_op =
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const BiasDataType*>(p_bias),
static_cast<const D0DataType*>(p_ds[0]),
reduce_tuple,
M,
N,
K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
StrideC1, StrideDs[0],
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
c1_element_op, d_element_op,
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op}; reduce_out_element_ops};
} }
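The gemm elementwise functors cross the same type-erased boundary: the caller packs their addresses into std::array<void*, 3> (one slot each for the a, b and c ops, plus a separate array per extra D tensor), and the device op casts every slot back to the functor type fixed by its template parameters and copies it by value, so the originals only have to stay alive for the duration of the call. A self-contained illustration of that round trip with made-up functor types, not CK code:

#include <array>

struct PassThrough { };                      // stand-in for an identity elementwise op
struct ScaledAdd   { float slope = 1.0f; };  // stand-in for a fused c-elementwise op

void consume(std::array<void*, 3> gemm_element_ops)
{
    // The op knows the concrete types from its template parameters, so each
    // void* slot can be cast back and dereferenced into a by-value copy.
    PassThrough a_op = *static_cast<PassThrough*>(gemm_element_ops[0]);
    PassThrough b_op = *static_cast<PassThrough*>(gemm_element_ops[1]);
    ScaledAdd   c_op = *static_cast<ScaledAdd*>(gemm_element_ops[2]);
    (void)a_op; (void)b_op; (void)c_op;
}

int main()
{
    PassThrough pass;
    ScaledAdd   c_op;
    consume({&pass, &pass, &c_op}); // the functors only need to outlive this call
}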
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
@@ -747,45 +773,74 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_bias,
std::array<const void*, 1> p_ds,
void* p_c, void* p_c,
const void* p_c0, std::array<void*, NumReduce> p_reduces,
const void* p_c1, ck::index_t M,
void* p_dxs, ck::index_t N,
index_t MRaw, ck::index_t K,
index_t NRaw, ck::index_t StrideA,
index_t KRaw, ck::index_t StrideB,
index_t StrideA, ck::index_t StrideC,
index_t StrideB, std::array<ck::index_t, 1> StrideDs,
index_t StrideC, std::array<void*, 3> gemm_element_ops,
index_t StrideC1, std::array<void*, 1> d_element_ops,
AElementwiseOperation a_element_op, std::array<void*, NumReduce> reduce_in_element_op,
BElementwiseOperation b_element_op, std::array<void*, NumReduce> reduce_out_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override index_t /* KBatch */ = 1) override
{ {
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs)); ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
D0ElementwiseOperation d_element_op =
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
static_cast<const C0DataType*>(p_c0), static_cast<const BiasDataType*>(p_bias),
static_cast<const C1DataType*>(p_c1), static_cast<const D0DataType*>(p_ds[0]),
dxs_tuple, reduce_tuple,
MRaw, M,
NRaw, N,
KRaw, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
StrideC1, StrideDs[0],
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
c1_element_op, d_element_op,
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op); reduce_out_element_ops);
} }
// polymorphic // polymorphic
@@ -800,7 +855,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGemmReduce_Xdl_CShuffle" str << "DeviceGemmBiasAddReduce_Xdl_CShuffle"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
...
@@ -9,91 +9,34 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename AElementwiseOperation, template <ck::index_t NumDTensor, ck::index_t NumReduce>
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator struct DeviceGemmReduce : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_bias,
std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
void* p_dxs, std::array<void*, NumReduce> p_reduces,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, ck::index_t StrideC,
AElementwiseOperation a_element_op, std::array<ck::index_t, NumDTensor> StrideDs,
BElementwiseOperation b_element_op, std::array<void*, 3> gemm_element_ops,
CElementwiseOperation c_element_op, std::array<void*, NumDTensor> d_element_ops,
DxsInElementwiseOperation dxs_in_element_op, std::array<void*, NumReduce> reduce_in_element_ops,
DxsReduceAccElementwiseOperation dxs_out_element_op, std::array<void*, NumReduce> reduce_out_element_ops,
ck::index_t BatchCount = 1) = 0; ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename AElementwiseOperation, template <ck::index_t NumDTensor, ck::index_t NumReduce>
typename BElementwiseOperation, using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
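After this refactor the two integers are the whole public contract: NumDTensor counts the extra D inputs fused into C and NumReduce counts the reduction outputs, so each operation family is named by a pair rather than by a bundle of functor types. As a hedged example of how client code might spell the handles (the <0, 2> and <1, 2> choices below are illustrative instantiations, not aliases defined in this PR):

// Plain gemm + reduce (no extra D tensor) with two reduction outputs:
using GemmReducePtr = ck::tensor_operation::device::DeviceGemmReducePtr<0, 2>;

// gemm + bias + add + reduce (one D tensor) with two reduction outputs:
using GemmBiasAddReducePtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>;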
...
@@ -32,14 +32,14 @@ template <typename ALayout,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename ReduceAccDataType, typename ReduceAccDataType,
typename DPtrsGlobal, typename ReducePtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsReduceOperation, typename ReduceOperations,
typename DxsInElementwiseOperation, typename ReduceInElementwiseOperations,
typename DxsReduceAccElementwiseOperation, typename ReduceAccElementwiseOperations,
typename DGlobalMemoryDataOperation, typename ReduceGlobalMemoryDataOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
@@ -74,11 +74,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation, struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>
{ {
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
@@ -350,8 +346,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
} }
} }
// assume D is packed tensor // assume Reduce is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw) static auto MakeReduceGridDescriptor_M(index_t MRaw)
{ {
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
@@ -379,7 +375,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
@@ -388,19 +384,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
ReduceAccDataType, ReduceAccDataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsReduceOperation, ReduceOperations,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsReduceAccElementwiseOperation, ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation, ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
DGridDesc_M, ReduceGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
@@ -443,7 +439,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
Argument(const ADataType* p_a_grid, Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
DPtrsGlobal p_ds_grid, ReducePtrsGlobal p_reduces_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
@@ -453,24 +449,24 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, ReduceInElementwiseOperations reduce_in_element_ops,
DxsReduceAccElementwiseOperation dxs_out_element_op) ReduceAccElementwiseOperations reduce_out_element_ops)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_ds_grid_{p_ds_grid}, p_reduces_grid_{p_reduces_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{}, reduce_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
dxs_in_element_op_{dxs_in_element_op}, reduce_in_element_ops_{reduce_in_element_ops},
dxs_out_element_op_{dxs_out_element_op} reduce_out_element_ops_{reduce_out_element_ops}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
@@ -481,8 +477,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ = reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
} }
} }
@@ -490,20 +486,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        DPtrsGlobal p_ds_grid_;
+        ReducePtrsGlobal p_reduces_grid_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
-        DGridDesc_M d_grid_desc_m_;
+        ReduceGridDesc_M reduce_grid_desc_m_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
+        typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
+            reduce_grid_desc_mblock_mperblock_;
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
-        DxsInElementwiseOperation dxs_in_element_op_;
-        DxsReduceAccElementwiseOperation dxs_out_element_op_;
+        ReduceInElementwiseOperations reduce_in_element_ops_;
+        ReduceAccElementwiseOperations reduce_out_element_ops_;
     };

     // Invoker
@@ -528,7 +525,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-            std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
+            std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
                       << std::endl;
         }
 #endif
@@ -554,16 +551,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    DPtrsGlobal,
+                    ReducePtrsGlobal,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    DxsInElementwiseOperation,
-                    DxsReduceAccElementwiseOperation,
+                    ReduceInElementwiseOperations,
+                    ReduceAccElementwiseOperations,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
+                    typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     true>;
@@ -576,16 +573,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.p_ds_grid_,
+                    arg.p_reduces_grid_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
-                    arg.dxs_in_element_op_,
-                    arg.dxs_out_element_op_,
+                    arg.reduce_in_element_ops_,
+                    arg.reduce_out_element_ops_,
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.d_grid_desc_mblock_mperblock_,
+                    arg.reduce_grid_desc_mblock_mperblock_,
                     arg.block_2_ctile_map_);
             }
             else
@@ -594,16 +591,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    DPtrsGlobal,
+                    ReducePtrsGlobal,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    DxsInElementwiseOperation,
-                    DxsReduceAccElementwiseOperation,
+                    ReduceInElementwiseOperations,
+                    ReduceAccElementwiseOperations,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
+                    typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     false>;
@@ -616,16 +613,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.p_ds_grid_,
+                    arg.p_reduces_grid_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
-                    arg.dxs_in_element_op_,
-                    arg.dxs_out_element_op_,
+                    arg.reduce_in_element_ops_,
+                    arg.reduce_out_element_ops_,
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.d_grid_desc_mblock_mperblock_,
+                    arg.reduce_grid_desc_mblock_mperblock_,
                     arg.block_2_ctile_map_);
             }
@@ -660,37 +657,75 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             CDataType* p_c,
-                             DPtrsGlobal p_dxs,
-                             index_t MRaw,
-                             index_t NRaw,
-                             index_t KRaw,
-                             index_t StrideA,
-                             index_t StrideB,
-                             index_t StrideC,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
-                             DxsInElementwiseOperation dxs_in_element_op,
-                             DxsReduceAccElementwiseOperation dxs_out_element_op)
+    static constexpr int NumReduce = ReduceOperations::Size();
+    static auto MakeArgument(const void* p_a,
+                             const void* p_b,
+                             const void* p_bias,
+                             std::array<const void*, 0> p_ds,
+                             void* p_c,
+                             std::array<void*, NumReduce> p_reduces,
+                             ck::index_t M,
+                             ck::index_t N,
+                             ck::index_t K,
+                             ck::index_t StrideA,
+                             ck::index_t StrideB,
+                             ck::index_t StrideC,
+                             std::array<ck::index_t, 0> StrideDs,
+                             std::array<void*, 3> gemm_element_ops,
+                             std::array<void*, 0> d_element_ops,
+                             std::array<void*, NumReduce> reduce_in_element_op,
+                             std::array<void*, NumReduce> reduce_out_element_op)
     {
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        p_dxs,
-                        MRaw,
-                        NRaw,
-                        KRaw,
+        (void)p_bias;
+        (void)p_ds;
+        (void)StrideDs;
+        (void)d_element_ops;
+
+        ReducePtrsGlobal reduce_tuple = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReducePtrsGlobal{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return static_cast<T*>(p_reduces[I]);
+            },
+            Number<NumReduce>{});
+
+        ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceInElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_in_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceAccElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_out_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        AElementwiseOperation a_element_op =
+            *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
+        BElementwiseOperation b_element_op =
+            *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
+        CElementwiseOperation c_element_op =
+            *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
+
+        return Argument{static_cast<const ADataType*>(p_a),
+                        static_cast<const BDataType*>(p_b),
+                        static_cast<CDataType*>(p_c),
+                        reduce_tuple,
+                        M,
+                        N,
+                        K,
                         StrideA,
                         StrideB,
                         StrideC,
                         a_element_op,
                         b_element_op,
                         c_element_op,
-                        dxs_in_element_op,
-                        dxs_out_element_op};
+                        reduce_in_element_ops,
+                        reduce_out_element_ops};
     }

     static auto MakeInvoker() { return Invoker{}; }
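Editor's note on the hunk above: the `generate_tuple` lambdas rebuild the typed `ReducePtrsGlobal` tuple from the type-erased `std::array<void*, NumReduce>` that the new external API accepts. The following is a rough, stand-alone analogue of that conversion in plain C++17, using `std::tuple` and `std::array` instead of ck's tuple utilities; every name in it is illustrative and not part of the library.

// Illustrative sketch only: turn a type-erased std::array<void*, N> into a
// typed tuple of pointers, mirroring what generate_tuple + remove_pointer_t
// do in the hunk above.
#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

template <typename PtrTuple, std::size_t... Is>
PtrTuple make_ptr_tuple_impl(const std::array<void*, sizeof...(Is)>& ptrs,
                             std::index_sequence<Is...>)
{
    // Element Is of PtrTuple is some T*; cast the matching void* back to that T*.
    return PtrTuple{static_cast<std::tuple_element_t<Is, PtrTuple>>(ptrs[Is])...};
}

template <typename PtrTuple>
PtrTuple make_ptr_tuple(const std::array<void*, std::tuple_size_v<PtrTuple>>& ptrs)
{
    return make_ptr_tuple_impl<PtrTuple>(
        ptrs, std::make_index_sequence<std::tuple_size_v<PtrTuple>>{});
}

// e.g. with two float reduction outputs (a mean and a square-mean buffer):
//   using ReducePtrs = std::tuple<float*, float*>;
//   ReducePtrs typed = make_ptr_tuple<ReducePtrs>({p_mean, p_square_mean});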
@@ -699,37 +734,73 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
+                        const void* p_bias,
+                        std::array<const void*, 0> p_ds,
                         void* p_c,
-                        void* p_dxs,
-                        index_t MRaw,
-                        index_t NRaw,
-                        index_t KRaw,
-                        index_t StrideA,
-                        index_t StrideB,
-                        index_t StrideC,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
-                        index_t /* KBatch */ = 1) override
+                        std::array<void*, NumReduce> p_reduces,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        std::array<ck::index_t, 0> StrideDs,
+                        std::array<void*, 3> gemm_element_ops,
+                        std::array<void*, 0> d_element_ops,
+                        std::array<void*, NumReduce> reduce_in_element_op,
+                        std::array<void*, NumReduce> reduce_out_element_op,
+                        ck::index_t = 1) override
     {
-        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
+        (void)p_bias;
+        (void)p_ds;
+        (void)StrideDs;
+        (void)d_element_ops;
+
+        ReducePtrsGlobal reduce_tuple = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReducePtrsGlobal{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return static_cast<T*>(p_reduces[I]);
+            },
+            Number<NumReduce>{});
+
+        ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceInElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_in_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceAccElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_out_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        AElementwiseOperation a_element_op =
+            *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
+        BElementwiseOperation b_element_op =
+            *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
+        CElementwiseOperation c_element_op =
+            *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));

         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
-                                          dxs_tuple,
-                                          MRaw,
-                                          NRaw,
-                                          KRaw,
+                                          reduce_tuple,
+                                          M,
+                                          N,
+                                          K,
                                           StrideA,
                                           StrideB,
                                           StrideC,
                                           a_element_op,
                                           b_element_op,
                                           c_element_op,
-                                          dxs_in_element_op,
-                                          dxs_out_element_op);
+                                          reduce_in_element_ops,
+                                          reduce_out_element_ops);
     }

     // polymorphic
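Editor's note: the `std::array<void*, …>` element-op parameters of the new interface hold addresses of functor objects; the device op casts each entry back to its concrete type and copies the functor by value, so the caller only needs to keep those objects alive for the duration of the `MakeArgument`/`MakeArgumentPointer` call. Below is a minimal, stand-alone illustration of that cast-and-copy idiom; the `UnaryDivide` struct is a placeholder written for this sketch, not the ck functor of the same name.

// Illustrative sketch only: pass a stateful functor through a void* array and
// recover it with the same *(static_cast<Op*>(ptr)) idiom used above.
#include <array>
#include <iostream>

struct UnaryDivide // placeholder stand-in, not ck::tensor_operation::element_wise::UnaryDivide
{
    int divider;
    float operator()(float x) const { return x / static_cast<float>(divider); }
};

template <typename Op>
Op recover_op(void* p_op)
{
    // Same pattern as *(static_cast<CElementwiseOperation*>(gemm_element_ops[2])) above.
    return *(static_cast<Op*>(p_op));
}

int main()
{
    UnaryDivide div_op{1024}; // e.g. divide-by-N for a mean reduction
    std::array<void*, 2> reduce_out_element_op = {&div_op, &div_op};

    UnaryDivide recovered = recover_op<UnaryDivide>(reduce_out_element_op[0]);
    std::cout << recovered(2048.0f) << std::endl; // prints 2
}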