"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "cd51758fade4119b3f6233444c3bfac91ed5eba9"
Commit 8ec065fd authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into gemm_bias

parents 698f3a3e 8e374781
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// This example demonstrates a single fused kernel that runs GEMM and layernorm together.
//
// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers
// together, provided the GEMM extent N (of M, N, K) is spanned by a single workgroup. For example,
// a kernel configured with NPerBlock = 128 can operate on any GEMM size with N <= 128.
//
// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
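//
// Per output row m, the layernorm step works out to the following (a sketch that matches the
// host reference in reference_gemm_layernorm.hpp included above):
//   acc(m, n) = acc_element_op((A * B)(m, n) + bias(n)) + add(m, n)
//   mean(m)   = (1/N) * sum_n acc(m, n)
//   meansq(m) = (1/N) * sum_n acc(m, n)^2
//   D(m, n)   = (acc(m, n) - mean(m)) / sqrt(meansq(m) - mean(m)^2 + epsilon) * gamma(n) + beta(n)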
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using C0DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
struct Relu
{
template <typename OutT, typename InT>
__host__ __device__ void operator()(OutT& y, const InT& x) const
{
y = x > 0 ? x : 0;
}
};
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
// Elementwise operation that operates on the output of matrix multiplication
// i.e., AccElementOp(A * B + bias)
using AccElementOp = Relu;
// Elementwise operation that operates on the output of layer normalization
using CElementOp = Relu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| Acc| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadCopy|
//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4>;
// clang-format on
using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADataType,
BDataType,
CDataType,
C0DataType,
AccDataType,
AElementOp,
BElementOp,
AccElementOp,
CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 128;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 128;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
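// Example invocation (assuming the binary keeps the CMake target name above): verify with
// integer init, time the kernel, and use the same problem size as the defaults set above:
//   ./example_gemm_xdl_layernorm_single_kernel_fp16 1 1 1 3840 128 4096 4096 4096 128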
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<AccDataType> acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<C0DataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
Tensor<C0DataType> c0_m_n_add(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<C0DataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
Tensor<C0DataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "c0_n_bias: " << c0_n_bias.mDesc << std::endl;
std::cout << "c0_m_n_add: " << c0_m_n_add.mDesc << std::endl;
std::cout << "c0_n_gamma: " << c0_n_gamma.mDesc << std::endl;
std::cout << "c0_n_beta: " << c0_n_beta.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
c0_m_n_add.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
c0_n_gamma.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 2});
c0_n_beta.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 5});
c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0});
acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpace());
DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpace());
DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpace());
DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c0_bias_buf.ToDevice(c0_n_bias.mData.data());
c0_add_buf.ToDevice(c0_m_n_add.mData.data());
c0_gamma_buf.ToDevice(c0_n_gamma.mData.data());
c0_beta_buf.ToDevice(c0_n_beta.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto acc_element_op = AccElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(c0_add_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(c0_bias_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(c0_gamma_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(c0_beta_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
acc_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// extra 6MN flops due to: bias + add + gamma + beta + norm_sub + norm_div,
// excluding reduction steps
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(6) * M * N;
// extra MN and 3N due to c0_add (MxN), bias (1xN), gamma (1xN), beta (1xN)
std::size_t bytes = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * 2 * M * N + sizeof(C0DataType) * 3 * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = bytes / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
auto ref_gemm = ReferenceInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
b_k_n,
c_m_n_host_result,
c0_n_bias,
c0_m_n_add,
c0_n_gamma,
c0_n_beta,
a_element_op,
b_element_op,
acc_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
if constexpr(std::is_same<CShuffleDataType, F32>::value)
{
pass &= ck::utils::check_err(
c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c");
}
else if constexpr(std::is_same<CShuffleDataType, F16>::value)
{
pass &= ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData,
"Error: Incorrect results c",
1e-2,
1e-2);
}
}
return pass ? 0 : 1;
}
@@ -32,6 +32,9 @@ struct DeviceBatchedGemm : public BaseOperator
                          ck::index_t StrideA,
                          ck::index_t StrideB,
                          ck::index_t StrideC,
+                         ck::index_t BatchStrideA,
+                         ck::index_t BatchStrideB,
+                         ck::index_t BatchStrideC,
                          AElementwiseOperation a_element_op,
                          BElementwiseOperation b_element_op,
                          CElementwiseOperation c_element_op,
...
@@ -341,6 +341,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
                  index_t StrideA,
                  index_t StrideB,
                  index_t StrideC,
+                 index_t BatchStrideA,
+                 index_t BatchStrideB,
+                 index_t BatchStrideC,
                  index_t M01,
                  index_t N01,
                  AElementwiseOperation a_element_op,
@@ -357,10 +360,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
                   DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
               c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-              compute_ptr_offset_of_batch_{
-                  type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
-                  type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
-                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
+              compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
               block_2_ctile_map_{
                   GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
               M01_{M01},
@@ -543,6 +543,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
                              index_t StrideA,
                              index_t StrideB,
                              index_t StrideC,
+                             index_t BatchStrideA,
+                             index_t BatchStrideB,
+                             index_t BatchStrideC,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op,
@@ -557,6 +560,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
                         StrideA,
                         StrideB,
                         StrideC,
+                        BatchStrideA,
+                        BatchStrideB,
+                        BatchStrideC,
                         1,
                         1,
                         a_element_op,
@@ -577,6 +583,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
                                       index_t StrideA,
                                       index_t StrideB,
                                       index_t StrideC,
+                                      index_t BatchStrideA,
+                                      index_t BatchStrideB,
+                                      index_t BatchStrideC,
                                       AElementwiseOperation a_element_op,
                                       BElementwiseOperation b_element_op,
                                       CElementwiseOperation c_element_op,
@@ -591,6 +600,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
                                           StrideA,
                                           StrideB,
                                           StrideC,
+                                          BatchStrideA,
+                                          BatchStrideB,
+                                          BatchStrideC,
                                           1,
                                           1,
                                           a_element_op,
...
@@ -46,13 +46,22 @@ __global__ void
     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

-    index_t group_id = 0;
-    for(index_t i = 0; i < group_count; i++)
+    index_t left     = 0;
+    index_t right    = group_count;
+    index_t group_id = index_t((left + right) / 2);
+    while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
+             block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
+          left <= right)
     {
-        group_id =
-            (block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
-                ? i
-                : group_id;
+        if(block_id < gemm_desc_ptr[group_id].BlockStart_)
+        {
+            right = group_id;
+        }
+        else
+        {
+            left = group_id;
+        }
+        group_id = index_t((left + right) / 2);
     }

     GridwiseGemm::template Run<HasMainKBlockLoop>(
...
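The hunk above replaces the linear scan over all groups with a binary search over the contiguous
[BlockStart_, BlockEnd_) block ranges. A minimal host-side sketch of the same lookup, using a
hypothetical GroupRange struct in place of the kernel's GemmDesc:

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the BlockStart_/BlockEnd_ members of GemmDesc.
struct GroupRange
{
    int32_t BlockStart_;
    int32_t BlockEnd_;
};

// Returns the group whose [BlockStart_, BlockEnd_) range contains block_id, mirroring the
// device-side loop introduced above; assumes the ranges are contiguous and cover block_id.
int32_t find_group(const std::vector<GroupRange>& groups, int32_t block_id)
{
    int32_t left     = 0;
    int32_t right    = static_cast<int32_t>(groups.size());
    int32_t group_id = (left + right) / 2;

    while(!(block_id >= groups[group_id].BlockStart_ && block_id < groups[group_id].BlockEnd_) &&
          left <= right)
    {
        if(block_id < groups[group_id].BlockStart_)
            right = group_id;
        else
            left = group_id;

        group_id = (left + right) / 2;
    }

    return group_id;
}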
@@ -30,6 +30,8 @@ struct ThreadwiseReduction
     static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");

+    using Op = OpReduce;
+
     template <typename SrcBufferType, typename DstBufferType>
     __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
     {
...
@@ -12,21 +12,27 @@ template <typename T, typename Enable = void>
 struct PrintAsType;

 template <typename T>
-struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::value>
+struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
 {
     using type = float;
+    __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
 };

 template <>
 struct PrintAsType<ck::half_t, void>
 {
     using type = float;
+    __host__ __device__ static void Print(const ck::half_t& p)
+    {
+        printf("%.3f ", static_cast<type>(p));
+    }
 };

 template <typename T>
-struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value>
+struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
 {
     using type = int;
+    __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
 };

 } // namespace detail
@@ -41,7 +47,6 @@ struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value>
 template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
 __device__ void print_shared(T const* p_shared, index_t num_elements)
 {
-    using PrintType = typename detail::PrintAsType<T>::type;
     constexpr index_t row_elements = row_bytes / sizeof(T);
     static_assert((element_stride >= 1 && element_stride <= row_elements),
                   "element_stride should between [1, row_elements]");
@@ -63,7 +68,7 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
             printf("elem %5d: ", i);
             for(index_t j = 0; j < row_elements; j += element_stride)
             {
-                printf("%.0f ", static_cast<PrintType>(p_shared[i + j]));
+                detail::PrintAsType<T>::Print(p_shared[i + j]);
             }
             printf("\n");
...
@@ -58,6 +58,33 @@ struct Add
     }
 };

+struct SquaredAdd
+{
+    template <class T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return type_convert<T>(0.0f);
+    };
+
+    __host__ __device__ static constexpr bool
+    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
+    {
+        return operation == InMemoryDataOperationEnum::AtomicAdd ||
+               operation == InMemoryDataOperationEnum::Set;
+    };
+
+    template <class T>
+    __host__ __device__ inline constexpr void operator()(T& a, T b) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the SquaredAdd accumulator!");
+
+        a = a + b * b;
+    }
+};
+
 struct Mul
 {
     template <typename T>
...
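SquaredAdd above accumulates a running sum of squares (a = a + b * b), the reduction used to form
the E[x^2] term of layernorm's variance (the host reference further down computes the same quantity
with an explicit loop). A small standalone illustration of the accumulation pattern, using a local
stand-in rather than the CK header:

#include <cstdio>

// Local stand-in with the same operator()(T& a, T b) shape as SquaredAdd above.
struct SquaredAddSketch
{
    template <class T>
    constexpr void operator()(T& a, T b) const
    {
        a = a + b * b;
    }
};

int main()
{
    const float x[4] = {1.f, 2.f, 3.f, 4.f};
    float sum_sq     = 0.f; // identity value of the reduction
    SquaredAddSketch op;

    for(float v : x)
        op(sum_sq, v); // accumulates 1 + 4 + 9 + 16

    std::printf("sum of squares: %f\n", sum_sq); // prints 30.000000
    return 0;
}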
@@ -220,12 +220,24 @@ struct Tensor
     Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}

+    template <typename OutT>
+    Tensor<OutT> CopyAsType()
+    {
+        Tensor<OutT> ret(mDesc);
+        for(size_t i = 0; i < mData.size(); i++)
+        {
+            ret.mData[i] = static_cast<OutT>(mData[i]);
+        }
+        return ret;
+    }
+
     Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}

     Tensor& operator=(const Tensor& other)
     {
         mDesc = other.mDesc;
         mData = other.mData;
+        return *this;
     }

     template <typename F>
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
template <typename ADataType,
typename BDataType,
typename CDataType,
typename C0DataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation>
struct ReferenceGemmLayernorm : public device::BaseOperator
{
using ReferenceGemmInstance = ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
element_wise::PassThrough>;
template <typename InDataType, typename OutDataType, typename ComputeDataType>
static void RunLayernorm(Tensor<OutDataType>& result,
const Tensor<ComputeDataType>& acc, // MxN
const Tensor<InDataType>& gamma, // 1xN
const Tensor<InDataType>& beta, // 1xN
const InDataType epsilon = 1e-5)
{
assert(acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);
size_t M = acc.mDesc.GetLengths()[0];
size_t N = acc.mDesc.GetLengths()[1];
Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
Tensor<ComputeDataType> acc_layernorm(acc);
// reduce N dim
for(size_t i = 0; i < M; i++)
{
ComputeDataType sum_acc_sq = 0;
ComputeDataType sum_acc = 0;
for(size_t j = 0; j < N; j++)
{
sum_acc_sq += acc_layernorm(i, j) * acc_layernorm(i, j);
sum_acc += acc_layernorm(i, j);
}
avg_acc_sq(i) = sum_acc_sq / N;
avg_acc(i) = sum_acc / N;
}
// normalize
acc_layernorm.ForEach([&](auto& self, auto idx) {
self(idx[0], idx[1]) =
(self(idx[0], idx[1]) - avg_acc(idx[0])) /
sqrt(avg_acc_sq(idx[0]) - avg_acc(idx[0]) * avg_acc(idx[0]) + epsilon);
});
// affine
acc_layernorm.ForEach([&](auto& self, auto idx) {
self(idx[0], idx[1]) = self(idx[0], idx[1]) * gamma(idx[1]) + beta(idx[1]);
});
// cast
result = acc_layernorm.template CopyAsType<OutDataType>();
}
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
Tensor<CDataType>& c_m_n,
const Tensor<C0DataType>& c0_n_bias, // 1xN
const Tensor<C0DataType>& c0_m_n_add, // MxN
const Tensor<C0DataType>& c0_n_gamma, // 1xN
const Tensor<C0DataType>& c0_n_beta, // 1xN
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op,
const CDataType epsilon = 1e-5)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
c_m_n_{c_m_n},
c0_n_bias_{c0_n_bias},
c0_m_n_add_{c0_m_n_add},
c0_n_gamma_{c0_n_gamma},
c0_n_beta_{c0_n_beta},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
c_element_op_{c_element_op},
epsilon_{epsilon}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
Tensor<CDataType>& c_m_n_;
const Tensor<C0DataType>& c0_n_bias_;
const Tensor<C0DataType>& c0_m_n_add_;
const Tensor<C0DataType>& c0_n_gamma_;
const Tensor<C0DataType>& c0_n_beta_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
CElementwiseOperation c_element_op_;
const CDataType epsilon_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
// using Argument = ReferenceGemm::Argument;
float Run(const Argument& arg)
{
Tensor<AccDataType> acc_m_n(arg.c_m_n_.mDesc);
acc_m_n.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(arg.a_m_k_,
arg.b_k_n_,
acc_m_n,
arg.a_element_op_,
arg.b_element_op_,
element_wise::PassThrough{});
// gemm
ref_invoker.Run(ref_argument);
// activation(acc + bias)
acc_m_n.ForEach([&](auto& self, auto idx) {
AccDataType out;
arg.acc_element_op_(out, acc_m_n(idx[0], idx[1]) + arg.c0_n_bias_(idx[1]));
self(idx[0], idx[1]) = out;
});
// add from other layers
acc_m_n.ForEach([&](auto& self, auto idx) {
self(idx[0], idx[1]) += arg.c0_m_n_add_(idx[0], idx[1]);
});
// layernorm
RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_gamma_, arg.c0_n_beta_);
// elementwise op
arg.c_m_n_.ForEach([&](auto& self, auto idx) {
arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1]));
});
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
Tensor<CDataType>& c_m_n,
const Tensor<C0DataType>& c0_n_bias, // 1xN
const Tensor<C0DataType>& c0_m_n_add, // MxN
const Tensor<C0DataType>& c0_n_gamma, // 1xN
const Tensor<C0DataType>& c0_n_beta, // 1xN
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op,
const CDataType epsilon = 1e-5)
{
return Argument{a_m_k,
b_k_n,
c_m_n,
c0_n_bias,
c0_m_n_add,
c0_n_gamma,
c0_n_beta,
a_element_op,
b_element_op,
acc_element_op,
c_element_op,
epsilon};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemmLayernorm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
                                int M,
                                int N,
                                int K,
+                               int BatchStrideA,
+                               int BatchStrideB,
+                               int BatchStrideC,
                                int StrideA,
                                int StrideB,
                                int StrideC,
@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
         std::size_t row,
         std::size_t col,
         std::size_t stride,
+        std::size_t batch_stride,
         auto layout) {
         if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({row * stride, stride, 1}));
+                                        std::vector<std::size_t>({batch_stride, stride, 1}));
         }
         else
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({col * stride, 1, stride}));
+                                        std::vector<std::size_t>({batch_stride, 1, stride}));
         }
     };

-    Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
-    Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
+    Tensor<ADataType> a_g_m_k(
+        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
+    Tensor<BDataType> b_g_k_n(
+        f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
     Tensor<CDataType> c_g_m_n_host_result(
-        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
+        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
     Tensor<CDataType> c_g_m_n_device_result(
-        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
+        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));

     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
@@ -150,6 +156,9 @@ bool profile_batched_gemm_impl(int do_verification,
                                               StrideA,
                                               StrideB,
                                               StrideC,
+                                              BatchStrideA,
+                                              BatchStrideB,
+                                              BatchStrideC,
                                               ck::tensor_operation::element_wise::PassThrough{},
                                               ck::tensor_operation::element_wise::PassThrough{},
                                               ck::tensor_operation::element_wise::PassThrough{},
...
@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[])
     const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
     const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

+    const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA;
+    const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
+    const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;
+
+    const int BatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
+    const int BatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
+    const int BatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;
+
     bool pass = ck::profiler::
         profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
             do_verification,
@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[])
             M,
             N,
             K,
-            (StrideA < 0) ? DefaultStrideA : StrideA,
-            (StrideB < 0) ? DefaultStrideB : StrideB,
-            (StrideC < 0) ? DefaultStrideC : StrideC,
+            BatchStrideA,
+            BatchStrideB,
+            BatchStrideC,
+            StrideA_,
+            StrideB_,
+            StrideC_,
             BatchCount);

     return pass ? 0 : 1;
...
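The defaults added above make each batch densely packed: for a row-major matrix the batch stride is
the number of rows times the leading dimension, and for a column-major matrix it is the number of
columns times the leading dimension. A small standalone check of that arithmetic (the sizes are
arbitrary):

#include <cassert>

int main()
{
    const int M = 256, N = 128, K = 64;

    // Leading dimensions for row-major A (MxK), column-major B (KxN), row-major C (MxN),
    // matching the DefaultStride values above.
    const int StrideA = K;
    const int StrideB = K;
    const int StrideC = N;

    // Packed batch strides, as computed in the hunk above.
    const int BatchStrideA = M * StrideA; // row-major A: M rows of StrideA elements
    const int BatchStrideB = N * StrideB; // column-major B: N columns of StrideB elements
    const int BatchStrideC = M * StrideC; // row-major C: M rows of StrideC elements

    assert(BatchStrideA == M * K);
    assert(BatchStrideB == K * N);
    assert(BatchStrideC == M * N);
    return 0;
}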
@@ -25,19 +25,19 @@ int main()
     pass = pass &&
            ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
-               true, 1, false, 1, M, N, K, K, N, N, BatchCount);
+               true, 1, false, 1, M, N, K, M * K, K * N, M * N, K, N, N, BatchCount);

     pass = pass &&
            ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
-               true, 1, false, 1, M, N, K, K, K, N, BatchCount);
+               true, 1, false, 1, M, N, K, M * K, K * N, M * N, K, K, N, BatchCount);

     pass = pass &&
            ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
-               true, 1, false, 1, M, N, K, M, N, N, BatchCount);
+               true, 1, false, 1, M, N, K, M * K, K * N, M * N, M, N, N, BatchCount);

     pass = pass &&
            ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
-               true, 1, false, 1, M, N, K, M, K, N, BatchCount);
+               true, 1, false, 1, M, N, K, M * K, K * N, M * N, M, K, N, BatchCount);

     std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;

     return pass ? 0 : 1;
...