Unverified Commit 72b7ae25 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Update staging branch. (#706)



* update daily build from rocm 5.4.3 to 5.5 (#693)

* Fix grouped_gemm_splitk kernels on MI300. (#694)

* replace amd_buffer_atomic_add with hip_atomic_add

* fix grouped_gemm_splitk kernels on mi300

* fix syntax

* revert experimental atomic_add changes

---------
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>

* Fix the group of quantization_int8 kernels on MI300. (#695)

* replace amd_buffer_atomic_add with hip_atomic_add

* fix grouped_gemm_splitk kernels on mi300

* fix syntax

* revert experimental atomic_add changes

* fix the group of kernels from ticket 723 on MI300

---------
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>

* Optimize bf16 conversion (#664)

* Add TypeConvert class and start refactoring

* Refactor TypeConvert as a struct

* Get back to template functions type_convert

* Add a type_convert_bf16_rtn, set rtz as default

* Clean up

* Add UnaryConvertPrecision struct for high-precision workloads

* Format

* Update type_convert to UnaryConvert on threadwise level

* Update UnaryConvertPrecision

* Format

* Fix chmod

* Add a flag to pick converion method

* Format

* Remove the added flag

* Merge elementwise op with type conversion

* Move type_convert to elemwise op, update the op

* Update type_convert_precision -> bf16_convert_rtn

* Clean up

* Update comments

* Update the CK_WORKAROUND_DENORM_FIX flag handling

* Update the unneeded op to work but warn user

* Remove the message

* Use a PassThrough instead of ConvertBF16RTN to calcaulate reference

* Format

* Add missing include

* Normalization/split k (#615)

* Add contraction profiler and tests (#701)

* Add contraction profiler and tests

* Build and style fixes

* Allow to use any elementwise operator for ref_contraction

* Introduce profile_contraction_scale and profile_contraction_bilinear

* Make ref_contraction generic and extend interface tests

* Stylistic minor fixes

* Extend test_contraction_interface

---------
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>
Co-authored-by: default avatarRostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>
Co-authored-by: default avatarBartłomiej Kocot <bartlomiejkocot98@gmail.com>
parent bbe74503
......@@ -10,8 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -20,6 +19,10 @@ namespace tensor_operation {
namespace device {
// Y = Normalization(X, Beta, Gamma)
// M: Invarient length
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
......@@ -68,7 +71,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
int blkGroupSize,
int numBlockTileIteration)
{
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
......@@ -117,10 +119,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
......@@ -132,7 +133,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (in_grid_desc_m_k_padded);
};
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
struct Argument : public BaseArgument
{
......@@ -162,26 +163,22 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
long_index_t invariant_total_length;
long_index_t reduce_total_length;
long_index_t invariant_length;
long_index_t reduce_length;
std::tie(invariant_total_length, reduce_total_length) =
std::tie(invariant_length, reduce_length) =
get_2d_lengths<Rank, NumReduceDim>(Lengths_);
blkGroupSize_ = 1;
numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize);
gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize_;
gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize);
x_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);
MakeSrc2dDescriptor(Lengths_, gammaStrides_, numBlockTileIteration_);
beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
isSweeponce_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
......@@ -202,7 +199,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
YElementwiseOperation y_elementwise_op_;
int blkGroupSize_;
int numBlockTileIteration_;
size_t gridSize_;
......@@ -286,6 +282,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
return false;
if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
return false;
};
}
else
......@@ -295,12 +294,12 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
return false;
};
if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
{
return false;
}
if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
{
return false;
}
};
// if fastest dim is not reduced
if constexpr(GammaSrcVectorDim == 0)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseWelford,
typename XDataType,
typename MeanVarDataType,
typename ComputeDataType,
typename XGridDesc_M_K,
typename MeanVarGridDesc_M_KBlock>
__global__ void
kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x_global,
MeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count)
{
GridwiseWelford::Run(x_grid_desc_m_k,
mean_var_grid_desc_m_kblock,
num_k_block_tile_iteration,
p_x_global,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename GridwiseWelfordNormalization,
typename MeanVarDataType,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename MeanVarGridDesc_M_KBlock,
typename CountGridDesc_M_KBlock,
typename XYGammaBetaGridDesc_M_K>
__global__ void
kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
const CountGridDesc_M_KBlock count_grid_desc_m_kblock,
const XYGammaBetaGridDesc_M_K x_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
index_t num_k_mean_var_count_iteration,
index_t num_k_block_tile_iteration,
index_t k_grid_size,
ComputeDataType epsilon,
const MeanVarDataType* const p_mean_global,
const MeanVarDataType* const p_variance_global,
const int32_t* const p_welford_count_global,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const YElementwiseOperation y_elementwise_op)
{
GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,
count_grid_desc_m_kblock,
x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
num_k_mean_var_count_iteration,
num_k_block_tile_iteration,
k_grid_size,
epsilon,
p_mean_global,
p_variance_global,
p_welford_count_global,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
y_elementwise_op);
};
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// Y = Normalization(X, Beta, Gamma)
// M: Invarient length
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XYVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorSize>
struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
YElementwiseOperation,
Rank,
NumReduceDim>
{
using MeanVarDataType = ComputeDataType;
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert(
((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
int kBlockSize,
int numBlockTileIteration)
{
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * kBlockSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
template <typename DoPads, index_t MPerTile, index_t KPerTile>
static auto MakeMeanVarDescriptor_M_K(index_t M, index_t K)
{
const auto grid_desc_m_k =
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
}
template <typename DoPads, index_t MPerTile, index_t KPerTile>
static auto MakeCountDescriptor_M_K(index_t M, index_t K)
{
const auto grid_desc_m_k =
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1));
return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
}
using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using Kernel1MeanVarGridDesc_M_KBlock =
decltype(MakeMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1));
using Kernel2MeanVarGridDesc_M_KBlock =
decltype(MakeMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
using Kernel2CountGridDesc_M_KBlock =
decltype(MakeCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
using GridwiseWelford = GridwiseNormalizationSplitK1st<XDataType,
ComputeDataType,
MeanVarDataType,
SrcGridDesc_M_K,
Kernel1MeanVarGridDesc_M_KBlock,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYVectorDim,
XSrcVectorSize>;
using GridwiseWelfordNormalization =
GridwiseNormalizationSplitK2nd<MeanVarDataType,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
Kernel2MeanVarGridDesc_M_KBlock,
Kernel2CountGridDesc_M_KBlock,
SrcGridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYVectorDim,
YDstVectorSize>;
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> lengths,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
YElementwiseOperation y_elementwise_op,
double epsilon,
const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
: p_x_(p_x),
p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
p_workspace_mean_{nullptr},
p_workspace_var_{nullptr},
p_workspace_count_{nullptr},
y_elementwise_op_(y_elementwise_op)
{
epsilon_ = static_cast<ComputeDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
numBlockTileIteration_ = 1;
while(true)
{
int testKGridSize =
math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_);
// we want the kGridSize_ be not more than 128
if(testKGridSize <= 128)
break;
++numBlockTileIteration_;
};
kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_);
gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_;
// We do not use vector load for mean, var and count
static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize;
numMeanVarCountIteration_ =
math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize);
x_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, gammaStrides_, kGridSize_, numBlockTileIteration_);
beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, kGridSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
kernel1_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, false>, M_BlockTileSize, 1>(MRaw_,
kGridSize_);
kernel2_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
kernel2_count_grid_desc_m_kblock_ =
MakeCountDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
}
ComputeDataType epsilon_;
const XDataType* p_x_;
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
YDataType* p_y_;
void* p_workspace_mean_;
void* p_workspace_var_;
void* p_workspace_count_;
std::vector<index_t> Lengths_;
std::vector<index_t> xStrides_;
std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_;
YElementwiseOperation y_elementwise_op_;
int kGridSize_;
int numMeanVarCountIteration_;
int numBlockTileIteration_;
size_t gridSize_;
SrcGridDesc_M_K x_grid_desc_m_k_;
SrcGridDesc_M_K gamma_grid_desc_m_k_;
SrcGridDesc_M_K beta_grid_desc_m_k_;
SrcGridDesc_M_K y_grid_desc_m_k_;
Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_;
Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_;
index_t MRaw_; // invarient length
index_t KRaw_; // reduce length
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(arg.p_workspace_mean_ == nullptr || arg.p_workspace_var_ == nullptr ||
arg.p_workspace_count_ == nullptr)
throw std::runtime_error("wrong! WorkSpace pointer has not been set");
auto kernel1 = kernel_normalizationSplitK1st<GridwiseWelford,
XDataType,
MeanVarDataType,
ComputeDataType,
SrcGridDesc_M_K,
Kernel1MeanVarGridDesc_M_KBlock>;
auto kernel2 = kernel_normalizationSplitK2nd<GridwiseWelfordNormalization,
MeanVarDataType,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
Kernel2MeanVarGridDesc_M_KBlock,
Kernel2CountGridDesc_M_KBlock,
SrcGridDesc_M_K>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel1,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.kernel1_mean_var_grid_desc_m_kblock_,
arg.numBlockTileIteration_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.p_workspace_mean_),
static_cast<MeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_));
avg_time += launch_and_time_kernel(stream_config,
kernel2,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.kernel2_mean_var_grid_desc_m_kblock_,
arg.kernel2_count_grid_desc_m_kblock_,
arg.x_grid_desc_m_k_,
arg.gamma_grid_desc_m_k_,
arg.beta_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.numMeanVarCountIteration_,
arg.numBlockTileIteration_,
arg.kGridSize_,
arg.epsilon_,
static_cast<MeanVarDataType*>(arg.p_workspace_mean_),
static_cast<MeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_),
arg.p_x_,
arg.p_gamma_,
arg.p_beta_,
arg.p_y_,
arg.y_elementwise_op_);
return avg_time;
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
size_t workspace_size = 0;
int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;
// workspace for welford intermediate mean
workspace_size += welford_size * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate variance
workspace_size += welford_size * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate count
workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64;
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;
// setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz = welford_size * sizeof(MeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford varirance
pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
index_t variance_space_sz = welford_size * sizeof(MeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welford count
pArg_->p_workspace_count_ =
reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return false;
}
else
{
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
return false;
if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
return false;
};
}
else
{
if(p_arg_->xStrides_[Rank - 1] != 1)
return false;
if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
return false;
if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
return false;
};
// if fastest dim is not reduced
if constexpr(GammaSrcVectorDim == 0)
{
if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
return false;
if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
return false;
}
else // if fastest dim is reduced
{
if(p_arg_->gammaStrides_[Rank - 1] != 1)
return false;
if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
return false;
}
// if fastest dim is not reduced
if constexpr(BetaSrcVectorDim == 0)
{
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return false;
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
return false;
}
else // if fastest dim is reduced
{
if(p_arg_->betaStrides_[Rank - 1] != 1)
return false;
if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
return false;
}
if(p_arg_->kGridSize_ <= 1)
return false;
return true;
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
double epsilon,
const void* p_x,
const void* p_gamma,
const void* p_beta,
void* p_y,
void* p_saveMean,
void* p_saveInvVar,
YElementwiseOperation y_elementwise_op) override
{
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
// forward pass could speedup in the backward
ignore = p_saveMean;
ignore = p_saveInvVar;
return std::make_unique<Argument>(lengths,
xStrides,
gammaStrides,
betaStrides,
yStrides,
reduceDims,
y_elementwise_op,
epsilon,
static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceNormalizationSplitKImpl<" << BlockSize << ",";
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "XYSrcVectorDim_" << XYVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -56,6 +56,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
......@@ -86,6 +92,23 @@ struct UnaryConvert
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, bhalf_t>::value, "Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct Scale
{
__host__ __device__ Scale(float scale) : scale_(scale) {}
......
......@@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#if CK_WORKAROUND_DENORM_FIX
using ABDataTypeAdjusted =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
#else
......
......@@ -266,7 +266,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#if CK_WORKAROUND_DENORM_FIX
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else
using FloatABAdjusted = FloatAB;
......
......@@ -136,7 +136,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#if CK_WORKAROUND_DENORM_FIX
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else
using FloatABAdjusted = FloatAB;
......
......@@ -3,8 +3,8 @@
#pragma once
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp"
namespace ck {
template <typename GridwiseReduction,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename XDataType,
typename ComputeDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarGridDesc_M_KBlock,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize>
struct GridwiseNormalizationSplitK1st
{
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
static constexpr auto thread_buffer_desc_m_1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<ComputeDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
__device__ static int
GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
{
bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
if(is_rightmost_block)
{
int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock;
int kPerThread =
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
if(kPerBlockTail > 0)
{
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
int thread_max_len =
(thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
int delta = thread_max_len - kPerBlockTail;
delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
kPerThread += XSrcVectorSize - delta;
});
}
return kPerThread;
}
else
{
int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
}
}
// Calculate mean and variance by welford along k dimension
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x_global,
MeanVarDataType* const p_mean_global,
MeanVarDataType* const p_variance_global,
int32_t* const p_welford_count_global)
{
auto x_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * XSrcVectorSize,
true>{};
},
Number<ThreadBufferNumber>{});
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
var_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t k_grid_size = mean_var_grid_desc_m_kblock.GetLength(I1);
const index_t block_m_cluster_id = block_global_id / k_grid_size;
const index_t block_k_cluster_id = block_global_id % k_grid_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
ComputeDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(
block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_k_cluster_id * reduceSizePerBlock + thread_k_cluster_id * XSrcVectorSize));
auto mean_var_count_store_index = make_multi_index(
block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_k_cluster_id);
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarGridDesc_M_KBlock,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m_kblock, mean_var_count_store_index, PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
threadwise_welford.max_count_ =
GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
});
for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
{
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
});
}
int welford_count = 0;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
// The value of count is same for all I
if constexpr(I == MThreadSliceSize - 1)
welford_count = count;
});
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
mean_thread_buf,
mean_var_grid_desc_m_kblock,
mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
var_thread_buf,
mean_var_grid_desc_m_kblock,
var_global_val_buf);
if(block_m_cluster_id == 0 && thread_m_cluster_id == 0)
p_welford_count_global[block_k_cluster_id] = welford_count;
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename MeanVarDataType,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename MeanVarGridDesc_M_KBlock,
typename CountGridDesc_M_KBlock,
typename XYGammaBetaGridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize>
struct GridwiseNormalizationSplitK2nd
{
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize);
static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
static constexpr auto thread_buffer_desc_m_1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));
using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
using ThreadWelfordDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<ComputeDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
__device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
const CountGridDesc_M_KBlock& count_grid_desc_m_kblock,
const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
index_t num_k_mean_var_count_iteration,
index_t num_k_block_tile_iteration,
index_t k_grid_size,
ComputeDataType epsilon,
const MeanVarDataType* const p_mean_global,
const MeanVarDataType* const p_variance_global,
const int32_t* const p_welford_count_global,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const YElementwiseOperation y_elementwise_op)
{
// Thread/Block id
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t block_m_cluster_id = block_global_id / k_grid_size;
const index_t block_k_cluster_id = block_global_id % k_grid_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
// Global Memory
const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
const auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize());
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
// VGPR
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
in_welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
auto x_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * XSrcVectorSize,
true>{};
},
Number<ThreadBufferNumber>{});
auto gamma_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * GammaSrcVectorSize,
true>{};
},
Number<ThreadBufferNumber>{});
auto& beta_thread_buf = gamma_thread_buf;
auto& y_thread_buf = x_thread_buf;
// IO
auto threadwise_mean_var_load_m_kblock =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
ComputeDataType,
MeanVarGridDesc_M_KBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_grid_desc_m_kblock,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id));
auto threadwise_count_load_m_kblock =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
CountGridDesc_M_KBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
count_grid_desc_m_kblock,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
ComputeDataType,
XYGammaBetaGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
thread_k_cluster_id * XSrcVectorSize));
auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
ComputeDataType,
XYGammaBetaGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
GammaSrcVectorDim,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_m_k,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
thread_k_cluster_id * GammaSrcVectorSize));
auto threadwise_beta_load =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
ComputeDataType,
XYGammaBetaGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
BetaSrcVectorDim,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_m_k,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
thread_k_cluster_id * BetaSrcVectorSize));
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
XYGammaBetaGridDesc_M_K,
YElementwiseOperation,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
YDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
y_grid_desc_m_k,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
thread_k_cluster_id * YDstVectorSize),
y_elementwise_op);
// step1: Merge mean and variance
constexpr auto mean_var_count_thread_copy_step_I0_k =
make_multi_index(I0, KThreadClusterSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
welford_count_thread_buf(I) = 0;
});
for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k)
{
threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_mean_thread_buf);
threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_var_thread_buf);
threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock,
welford_count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
ThreadwiseWelford::Run(in_mean_thread_buf,
in_var_thread_buf,
in_welford_count_thread_buf,
mean_thread_buf,
var_thread_buf,
welford_count_thread_buf);
threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow(
mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k);
threadwise_count_load_m_kblock.MoveSrcSliceWindow(count_grid_desc_m_kblock,
mean_var_count_thread_copy_step_I0_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));
});
// step2: normalization
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
{
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf(i));
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) *
gamma_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) +
beta_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf(i),
y_grid_desc_m_k,
y_global_val_buf);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
} // end for (normalization)
}
};
} // namespace ck
......@@ -6,6 +6,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
......@@ -207,15 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// apply SrcElementwiseOperation on src_vector_container
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
SrcData src_v;
src_element_op_(src_v, src_vector_container.template AsType<SrcData>()[i]);
src_vector_container.template AsType<SrcData>()(i) = src_v;
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>(
......@@ -318,7 +310,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
// TODO type_convert is not used yet!!!!!
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
......@@ -342,19 +333,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
Number<num_dst_vector>{});
// do data transpose
// TODO type_convert is not used yet!!!!!
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
else
{
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
});
}
static_ford<SliceLengths>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
DstData dst_v;
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
dst_thread_scratch_(idx) = dst_v;
});
#endif
}
......
......@@ -976,37 +976,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
......@@ -1064,6 +1033,63 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
template <typename T>
struct NumericLimits
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<CDataType>& c_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
c_ms_ns_{c_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<CDataType>& c_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const ck::index_t K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const ck::index_t K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(ck::index_t k0 = 0; k0 < K0; ++k0)
{
for(ck::index_t k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
arg.c_ms_ns_(m0, m1, n0, n1) = v_acc;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.c_ms_ns_.mDesc.GetLengths()[0],
arg.c_ms_ns_.mDesc.GetLengths()[1],
arg.c_ms_ns_.mDesc.GetLengths()[2],
arg.c_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<CDataType>& c_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op)
{
return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
......@@ -6,6 +6,7 @@
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
......@@ -66,8 +67,26 @@ struct ReferenceGemm : public device::BaseOperator
ADataType v_a;
BDataType v_b;
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
......
......@@ -46,3 +46,33 @@ out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
....
Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
```
## Profile contraction kernels
```bash
#arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear)
#arg2: data type (0: fp32; 1: f64)\n"
#arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1])
#arg4: verification (0: no; 1: yes)
#arg5: initialization (0: no init; 1: integer value; 2: decimal value)
#arg6: print tensor value (0: no; 1: yes)
#arg7: time kernel (0: no, 1: yes)
#arg8 and arg9: alpha and beta
#arg10 to 15: M0, M1, N0, N1, K0, K1
#arg16 to 31: Strides for A, B, D and E (skip for default)
################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1
./bin/ckProfiler contraction_bilinear 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128
```
Result (MI100)
```bash
a_m_k: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
b_k_n: dim 4, lengths {128, 128, 128, 128}, strides {128, 1, 2097152, 16384}
d_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
e_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
....
Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include <limits>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace profiler {
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale;
template <typename ALayout,
typename BLayout,
typename CDELayout,
typename DataType,
typename DTupleDataType,
typename CDElementOp>
int profile_contraction_impl(ck::index_t do_verification,
ck::index_t init_method,
bool do_log,
bool time_kernel,
CDElementOp cde_element_op,
const std::vector<ck::index_t>& M,
const std::vector<ck::index_t>& N,
const std::vector<ck::index_t>& K,
const std::vector<ck::index_t>& StridesA,
const std::vector<ck::index_t>& StridesB,
const std::vector<ck::index_t>& StridesE,
const std::vector<ck::index_t>& StridesD)
{
bool pass = true;
auto f_host_tensor_descriptor = [](const std::vector<ck::index_t>& dims01,
const std::vector<ck::index_t>& dims23,
const std::vector<ck::index_t>& strides) {
std::vector<std::size_t> dims_szt(dims01.begin(), dims01.end());
dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end());
std::vector<std::size_t> strides_szt(strides.begin(), strides.end());
return HostTensorDescriptor(dims_szt, strides);
};
Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB));
Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
e_device_buf.SetZero();
d_device_buf.ToDevice(d_m_n.mData.data());
const std::vector<index_t> a_ms_ks_lengths = {M[0], M[1], K[0], K[1]};
const std::vector<index_t> b_ns_ks_lengths = {N[0], N[1], K[0], K[1]};
const std::vector<index_t> e_ms_ns_lengths = {M[0], M[1], N[0], N[1]};
const std::vector<index_t> d_m_n_lengths = {M[0], M[1], N[0], N[1]};
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
constexpr ck::index_t NumDim = 2;
using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDim,
NumDim,
NumDim,
DataType,
DataType,
DTupleDataType,
DataType,
AElementOp,
BElementOp,
CDElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference op
if(do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDim,
NumDim,
NumDim,
DataType,
DataType,
DataType,
DataType,
AElementOp,
BElementOp>;
auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker();
Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
auto ref_argument =
ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_m_n_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_m_n_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_m_n_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_m_n_host_result.mDesc.GetLengths()[3]; ++n1)
{
if constexpr(is_same<CDElementOp, Bilinear>::value)
{
cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
c_m_n_host_result(m0, m1, n0, n1),
d_m_n(m0, m1, n0, n1));
}
else if constexpr(is_same<CDElementOp, Scale>::value)
{
cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
c_m_n_host_result(m0, m1, n0, n1));
}
else
{
static_assert("Unsupported CDElementOp in contraction profiler.");
}
}
}
}
}
}
std::string best_op_name;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
std::unique_ptr<tensor_operation::device::BaseArgument> argument_ptr;
if constexpr(is_same<CDElementOp, Bilinear>::value)
{
argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<DataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
a_ms_ks_lengths,
StridesA,
b_ns_ks_lengths,
StridesB,
std::array<std::vector<ck::index_t>, 1>{d_m_n_lengths},
std::array<std::vector<ck::index_t>, 1>{StridesD},
e_ms_ns_lengths,
StridesE,
a_element_op,
b_element_op,
cde_element_op);
}
else if constexpr(is_same<CDElementOp, Scale>::value)
{
argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<DataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 0>{},
static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
a_ms_ks_lengths,
StridesA,
b_ns_ks_lengths,
StridesB,
std::array<std::vector<ck::index_t>, 0>{},
std::array<std::vector<ck::index_t>, 0>{},
e_ms_ns_lengths,
StridesE,
a_element_op,
b_element_op,
cde_element_op);
}
else
{
static_assert("Unsupported CDElementOp in contraction profiler.");
}
auto invoker_ptr = op_ptr->MakeInvokerPointer();
auto nelems_m = M[0] * M[1];
auto nelems_n = N[0] * N[1];
auto nelems_k = K[0] * K[1];
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
e_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString();
float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nelems_m * nelems_n * nelems_k;
std::size_t num_btype = sizeof(DataType) * nelems_m * nelems_k +
sizeof(DataType) * nelems_k * nelems_n +
sizeof(DataType) * nelems_m * nelems_n;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
float threshold =
static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon();
pass = pass & ck::utils::check_err(e_m_n_device_result,
e_m_n_host_result,
"Error: incorrect results!",
threshold,
threshold);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}
if constexpr(is_same<DataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<DataType, double>::value)
{
std::cout << "Best Perf for datatype = f64";
}
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout = RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout = ColumnMajor";
}
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout = RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout = ColumnMajor";
}
if constexpr(is_same<CDELayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " CDELayout = RowMajor";
}
else if constexpr(is_same<CDELayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " CDELayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StridesA = " << StridesA
<< " StridesB = " << StridesB << " StridesE = " << StridesE << " : " << best_avg_time
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/ck.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale;
enum struct ContractionMatrixLayout
{
MK_KN_MN_MN, // 0
MK_NK_MN_MN, // 1
KM_KN_MN_MN, // 2
KM_NK_MN_MN, // 3
};
enum struct ContractionDataType
{
F32_F32_F32_F32, // 0
F64_F64_F64_F64, // 1
};
inline void collect_index_params(char* argv[],
std::vector<ck::index_t>& params,
const ck::index_t from,
const ck::index_t num)
{
for(ck::index_t p = from; p < from + num; p++)
params.push_back(std::stoi(argv[p]));
}
// Defualt strides for row-major: {Dim1 * Dim2 * Dim3, Dim2 * Dim3, Dim3, 1}
// Defualt strides for column-major: {Dim1, 1, Dim0 * Dim1 * Dim3, Dim0 * Dim1}
inline void
assign_default_strides(Row, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
{
strides = {dims[1] * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1};
}
inline void
assign_default_strides(Col, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
{
strides = {dims[1], 1, dims[0] * dims[1] * dims[3], dims[0] * dims[1]};
}
......@@ -30,6 +30,8 @@ set(PROFILER_SOURCES
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_grouped_gemm_fastgelu.cpp
profile_contraction_bilinear.cpp
profile_contraction_scale.cpp
)
set(PROFILER_EXECUTABLE ckProfiler)
......@@ -70,4 +72,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <vector>
#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
#include "profiler_operation_registry.hpp"
#define OP_NAME "contraction_bilinear"
#define OP_DESC "CONTRACTION+Bilinear"
static void print_helper_msg()
{
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64)\n"
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< "arg8 and arg9: alpha and beta\n"
<< "arg10 to 15: M0, M1, N0, N1, K0, K1\n"
<< "arg16 to 31: Strides for A, B, D and E (skip for default)\n"
<< std::endl;
}
int profile_contraction_bilinear(int argc, char* argv[])
{
const bool default_strides = argc == 16;
if(argc != 32 && argc != 16)
{
print_helper_msg();
exit(1);
}
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const ck::index_t init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const float alpha = std::stof(argv[8]);
const float beta = std::stof(argv[9]);
std::vector<ck::index_t> M;
std::vector<ck::index_t> N;
std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 10;
collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2);
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesE;
std::vector<ck::index_t> StridesD;
if(!default_strides)
{
collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
collect_index_params(argv, StridesE, dims_arg_num + 14, 4);
collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
}
using F32 = float;
using F64 = double;
auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) {
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CDELayout = decltype(cde_layout);
using DataType = decltype(type);
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
}
bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ck::Tuple<DataType>,
Bilinear>(do_verification,
init_method,
do_log,
time_kernel,
Bilinear{alpha, beta},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
};
if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{
return profile(Col{}, Row{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{
return profile(Col{}, Col{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{
return profile(Col{}, Row{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{
return profile(Col{}, Col{}, Row{}, F64{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment