Unverified Commit 8c4897d1 authored by Rostyslav Geyyer, committed by GitHub

Merge branch 'develop' into lwpck-756

parents 9ba9ebec 9e86ebd6
@@ -10,12 +10,14 @@
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
-#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/hip_check_error.hpp"

namespace ck {
namespace tensor_operation {
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
{
-const auto grid_desc_m_g =
-make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
+const auto grid_desc_m_g = make_naive_tensor_descriptor(
+make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));

const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto reduceLength = blkGroupSize;
-const auto grid_desc_m_k =
-make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
+const auto grid_desc_m_k = make_naive_tensor_descriptor(
+make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));

const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
-// we want the blkGroupSize be not more than 128
-if(testBlkGroupSize <= 128)
+// we want the blkGroupSize be not more than 16
+if(testBlkGroupSize <= 16)
break;

iterations++;
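For reference, a minimal host-side sketch of the tuning loop in this hunk, showing how iterations and blkGroupSize fall out of reduce_length_ and K_BlockTileSize under the new cap of 16; the numeric values below are illustrative, not taken from the commit.

// Standalone sketch of the blkGroupSize search shown above.
// reduce_length = 100000 and K_BlockTileSize = 256 are example values only.
#include <cstdio>

int main()
{
    const int reduce_length   = 100000;
    const int K_BlockTileSize = 256;

    int iterations = 1;
    while(true)
    {
        int testBlkGroupSize =
            (reduce_length + (K_BlockTileSize * iterations) - 1) / (K_BlockTileSize * iterations);

        // the commit lowers this cap from 128 to 16
        if(testBlkGroupSize <= 16)
            break;

        iterations++;
    }

    const int blkGroupSize =
        (reduce_length + (K_BlockTileSize * iterations) - 1) / (K_BlockTileSize * iterations);

    // With the sample numbers: iterations = 25, blkGroupSize = 16,
    // i.e. 16 workgroups cooperate per M tile and each loops 25 times over K tiles.
    std::printf("iterations=%d blkGroupSize=%d\n", iterations, blkGroupSize);
    return 0;
}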
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
void* workspace_mean_;
void* workspace_variance_;
void* workspace_count_;
+void* control_;
};

size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// workspace for welford intermediate count
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
+
+// workspace for barrier objects, each barrier object consists of two integers
+// TODO: allocate barrier object memory globally to reuse it by other operators
+workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
+sizeof(int) * 2;
}

return (workspace_size);
@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// setup buffer used for intermediate welford mean
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// setup buffer used for intermediate welfor count
pArg_->workspace_count_ =
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
+
+index_t count_space_sz =
+pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
+
+count_space_sz = math::integer_least_multiple(count_space_sz, 64);
+
+pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
+
+index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
+M_BlockTileSize * sizeof(int) * 2;
+
+hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
};
};
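Taken together, the two hunks above carve the single p_workspace_ allocation into four consecutive regions (Welford mean, variance, count, and the new barrier/control area), each rounded up for 64-byte alignment. A minimal host-side sketch of that layout arithmetic follows; the concrete sizes and the float type for MeanVarDataType are assumptions for illustration, not values from the commit.

// Sketch of the workspace layout implied by GetWorkSpaceSize/SetWorkSpacePointer above.
// invariant_length, blkGroupSize and M_BlockTileSize are example values, not CK defaults.
#include <cstddef>
#include <cstdio>

static std::size_t align_up(std::size_t sz, std::size_t alignment)
{
    return (sz + alignment - 1) / alignment * alignment;
}

int main()
{
    const std::size_t invariant_length = 1024; // e.g. the channel count C of an NHWC batchnorm
    const std::size_t blkGroupSize     = 16;   // workgroups cooperating on each M tile
    const std::size_t M_BlockTileSize  = 128;

    const std::size_t mean_sz     = align_up(invariant_length * blkGroupSize * sizeof(float), 64);
    const std::size_t variance_sz = align_up(invariant_length * blkGroupSize * sizeof(float), 64);
    const std::size_t count_sz    = align_up(invariant_length * blkGroupSize * sizeof(int), 64);

    // one barrier object (two ints) per M block tile, zero-filled before the kernel launch
    const std::size_t control_sz =
        (invariant_length + M_BlockTileSize - 1) / M_BlockTileSize * sizeof(int) * 2;

    // byte offsets of each region inside the single p_workspace_ allocation
    std::printf("mean@0 variance@%zu count@%zu control@%zu total=%zu\n",
                mean_sz,
                mean_sz + variance_sz,
                mean_sz + variance_sz + count_sz,
                mean_sz + variance_sz + count_sz + control_sz);
    return 0;
}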
@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
+
+using GridwiseMultiblockBatchNormForward_ =
+GridwiseMultiblockBatchNormForward<XDataType,
+YDataType,
+AccDataType,
+ScaleDataType,
+BiasDataType,
+MeanVarDataType,
+YElementwiseOp,
+XYGridDesc_M_K,
+MeanVarCountGridDesc_M_G,
+MeanVarCountGridDesc_M_K,
+ScaleBiasMeanVarGridDesc_M,
+ScaleBiasMeanVarGridDesc_M,
+GetReduceCountPerThreadFunctor,
+BlockSize,
+MThreadClusterSize,
+KThreadClusterSize,
+MThreadSliceSize,
+KThreadSliceSize,
+XSrcYDstVectorDim,
+XSrcVectorSize,
+YDstVectorSize,
+ScaleSrcVectorSize,
+BiasSrcVectorSize,
+MeanVarSrcDstVectorSize>;

using GridwiseMultiblockWelfordFirstHalf_ =
GridwiseMultiblockWelfordFirstHalf<XDataType,
AccDataType,

@@ -441,78 +487,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
-index_t numMeanVarCountBlockTileIteration =
-(arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;
-
-const auto kern_multiblock_welford_first_half =
-kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
-XDataType,
-MeanVarDataType,
-XYGridDesc_M_K,
-MeanVarCountGridDesc_M_G,
-GetReduceCountPerThreadFunctor>;
-
-const auto kern_welford_second_half_batchnorm_forward_final =
-kernel_welford_second_half_batchnorm_forward_final<
-GridwiseWelfordSecondHalfBatchNormForwardFinal_,
-XDataType,
-YDataType,
-AccDataType,
-ScaleDataType,
-BiasDataType,
-MeanVarDataType,
-YElementwiseOp,
-XYGridDesc_M_K,
-MeanVarCountGridDesc_M_K,
-ScaleBiasMeanVarGridDesc_M,
-ScaleBiasMeanVarGridDesc_M>;
-
-avg_time +=
-launch_and_time_kernel(stream_config,
-kern_multiblock_welford_first_half,
-dim3(arg.gridSize_),
-dim3(BlockSize),
-0,
-arg.x_grid_desc_m_k_,
-mean_var_count_grid_desc_m_g,
-get_reduce_count_per_thread,
-arg.numBlockTileIteration_,
-arg.p_x_,
-static_cast<MeanVarDataType*>(arg.workspace_mean_),
-static_cast<MeanVarDataType*>(arg.workspace_variance_),
-static_cast<int32_t*>(arg.workspace_count_));
-
-avg_time +=
-launch_and_time_kernel(stream_config,
-kern_welford_second_half_batchnorm_forward_final,
-dim3(arg.gridSize_),
-dim3(BlockSize),
-0,
-arg.x_grid_desc_m_k_,
-arg.y_grid_desc_m_k_,
-mean_var_count_grid_desc_m_k,
-arg.scale_grid_desc_m_,
-arg.bias_grid_desc_m_,
-arg.mean_var_grid_desc_m_,
-arg.blkGroupSize_,
-arg.numBlockTileIteration_,
-numMeanVarCountBlockTileIteration,
-arg.epsilon_,
-static_cast<MeanVarDataType*>(arg.workspace_mean_),
-static_cast<MeanVarDataType*>(arg.workspace_variance_),
-static_cast<int32_t*>(arg.workspace_count_),
-arg.p_x_,
-arg.p_scale_,
-arg.p_bias_,
-arg.y_elementwise_op_,
-arg.p_y_,
-arg.updateMovingAverage_,
-arg.averageFactor_,
-arg.resultRunningMean_,
-arg.resultRunningVariance_,
-arg.saveMeanInvVariance_,
-arg.resultSaveMean_,
-arg.resultSaveInvVariance_);
+// It is found that:
+// 1) gfx1030 does not support the GLC enabled vector load/store, so using the
+// two-kernel method for gfx1030
+// 2) Profiler on gfx908 could hang even though it works when running examples
+// 3) Single-kernel method works on gfx1100, but the performance it not better
+// than two-kernel method (due to more warps participating the barrier)
+if(ck::get_device_name() == "gfx90a")
+{
+const auto kern_multiblock_batchnorm_fwd_ =
+kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
+XDataType,
+YDataType,
+AccDataType,
+ScaleDataType,
+BiasDataType,
+MeanVarDataType,
+YElementwiseOp,
+XYGridDesc_M_K,
+MeanVarCountGridDesc_M_G,
+MeanVarCountGridDesc_M_K,
+ScaleBiasMeanVarGridDesc_M,
+ScaleBiasMeanVarGridDesc_M,
+GetReduceCountPerThreadFunctor>;
+
+avg_time += launch_and_time_kernel(
+stream_config,
+kern_multiblock_batchnorm_fwd_,
+dim3(arg.gridSize_),
+dim3(BlockSize),
+0,
+arg.x_grid_desc_m_k_,
+arg.y_grid_desc_m_k_,
+mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
+// workspace by multiple workgroups
+mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
+// workspace by each workgroup
+arg.scale_grid_desc_m_,
+arg.bias_grid_desc_m_,
+arg.mean_var_grid_desc_m_,
+get_reduce_count_per_thread,
+arg.numBlockTileIteration_,
+arg.epsilon_,
+arg.p_x_,
+static_cast<MeanVarDataType*>(arg.workspace_mean_),
+static_cast<MeanVarDataType*>(arg.workspace_variance_),
+static_cast<int32_t*>(arg.workspace_count_),
+static_cast<int*>(arg.control_),
+arg.p_scale_,
+arg.p_bias_,
+arg.y_elementwise_op_,
+arg.p_y_,
+arg.updateMovingAverage_, // true or false
+arg.averageFactor_,
+arg.resultRunningMean_,
+arg.resultRunningVariance_,
+arg.saveMeanInvVariance_, // true or false
+arg.resultSaveMean_,
+arg.resultSaveInvVariance_);
+}
+else
+{
+const auto kern_multiblock_welford_first_half =
+kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
+XDataType,
+MeanVarDataType,
+XYGridDesc_M_K,
+MeanVarCountGridDesc_M_G,
+GetReduceCountPerThreadFunctor>;
+
+const auto kern_welford_second_half_batchnorm_forward_final =
+kernel_welford_second_half_batchnorm_forward_final<
+GridwiseWelfordSecondHalfBatchNormForwardFinal_,
+XDataType,
+YDataType,
+AccDataType,
+ScaleDataType,
+BiasDataType,
+MeanVarDataType,
+YElementwiseOp,
+XYGridDesc_M_K,
+MeanVarCountGridDesc_M_K,
+ScaleBiasMeanVarGridDesc_M,
+ScaleBiasMeanVarGridDesc_M>;
+
+avg_time += launch_and_time_kernel(
+stream_config,
+kern_multiblock_welford_first_half,
+dim3(arg.gridSize_),
+dim3(BlockSize),
+0,
+arg.x_grid_desc_m_k_,
+mean_var_count_grid_desc_m_g,
+get_reduce_count_per_thread,
+arg.numBlockTileIteration_,
+arg.p_x_,
+static_cast<MeanVarDataType*>(arg.workspace_mean_),
+static_cast<MeanVarDataType*>(arg.workspace_variance_),
+static_cast<int32_t*>(arg.workspace_count_));
+
+avg_time += launch_and_time_kernel(
+stream_config,
+kern_welford_second_half_batchnorm_forward_final,
+dim3(arg.gridSize_),
+dim3(BlockSize),
+0,
+arg.x_grid_desc_m_k_,
+arg.y_grid_desc_m_k_,
+mean_var_count_grid_desc_m_k,
+arg.scale_grid_desc_m_,
+arg.bias_grid_desc_m_,
+arg.mean_var_grid_desc_m_,
+arg.blkGroupSize_,
+arg.numBlockTileIteration_,
+arg.epsilon_,
+static_cast<MeanVarDataType*>(arg.workspace_mean_),
+static_cast<MeanVarDataType*>(arg.workspace_variance_),
+static_cast<int32_t*>(arg.workspace_count_),
+arg.p_x_,
+arg.p_scale_,
+arg.p_bias_,
+arg.y_elementwise_op_,
+arg.p_y_,
+arg.updateMovingAverage_,
+arg.averageFactor_,
+arg.resultRunningMean_,
+arg.resultRunningVariance_,
+arg.saveMeanInvVariance_,
+arg.resultSaveMean_,
+arg.resultSaveInvVariance_);
+};
}
else
{
......
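The single-kernel gfx90a path added above relies on the per-M-tile "barrier objects" carved out of the workspace (two integers each, zeroed with hipMemset before launch), presumably so that the workgroups sharing one M tile can synchronize inside a single kernel instead of splitting the Welford reduction across two launches. CK's actual gridwise barrier is not shown in this commit; purely as an illustration of the idea, a generic two-counter barrier in HIP might look like the sketch below. All names here are hypothetical, and the pattern assumes every participating workgroup is resident on the GPU at the same time.

// Illustrative sketch only: a two-integer barrier object shared by the workgroups that
// cooperate on one M tile. Not CK's implementation.
#include <hip/hip_runtime.h>

struct BarrierObject
{
    int arrive_count; // how many workgroups have reached the barrier in this round
    int release_seq;  // generation counter bumped by the last workgroup to arrive
};

__device__ void grid_tile_barrier(BarrierObject* bar, int num_workgroups_per_tile)
{
    __syncthreads();  // make the whole workgroup arrive together
    __threadfence();  // make this workgroup's prior global writes visible to the others

    if(threadIdx.x == 0)
    {
        // remember which generation we are waiting on (atomicAdd with 0 acts as an atomic read)
        const int my_seq = atomicAdd(&bar->release_seq, 0);

        if(atomicAdd(&bar->arrive_count, 1) == num_workgroups_per_tile - 1)
        {
            // last arrival: reset the counter and release everyone
            atomicExch(&bar->arrive_count, 0);
            atomicAdd(&bar->release_seq, 1);
        }
        else
        {
            // spin until the last workgroup bumps the generation counter
            while(atomicAdd(&bar->release_seq, 0) == my_seq) {}
        }
    }

    __syncthreads(); // hold the rest of the workgroup until thread 0 is released
}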
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim,
bool UseMultiblockInK,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
const std::array<index_t, Rank>& xyStrides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleXYLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
const auto tupleXYStrides =
generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
const auto grid_desc_m_k = [&]() {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
Number<NumBatchNormReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(raw_grid_desc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto grid_desc_m_g = make_naive_tensor_descriptor(
make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_g_padded =
transform_tensor_descriptor(grid_desc_m_g,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_g_padded);
};
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto reduceLength = blkGroupSize;
const auto grid_desc_m_k = make_naive_tensor_descriptor(
make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad =
math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
static auto
MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
const std::array<index_t, NumInvariantDim>& strides)
{
const auto tupleLengths =
generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
const auto tupleStrides =
generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
auto grid_desc_m = transform_tensor_descriptor(
raw_grid_desc,
make_tuple(make_merge_transform(tupleLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded =
transform_tensor_descriptor(grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, mPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (grid_desc_m_padded);
};
using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const XDataType* p_x,
const ScaleDataType* p_scale,
const BiasDataType* p_bias,
const YElementwiseOp y_elementwise_op,
double epsilon,
YDataType* p_y,
MeanVarDataType* resultSaveMean,
MeanVarDataType* resultSaveInvVariance,
double averageFactor,
MeanVarDataType* resultRunningMean,
MeanVarDataType* resultRunningVariance)
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_scale_(p_scale),
p_bias_(p_bias),
y_elementwise_op_(y_elementwise_op),
p_y_(p_y),
resultSaveMean_(resultSaveMean),
resultSaveInvVariance_(resultSaveInvVariance),
resultRunningMean_(resultRunningMean),
resultRunningVariance_(resultRunningVariance)
{
xyLengths_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
xStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
yStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);
std::tie(invariant_length_, reduce_length_) =
get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
averageFactor_ = type_convert<AccDataType>(averageFactor);
updateMovingAverage_ =
(resultRunningMean != nullptr && resultRunningVariance != nullptr);
saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);
if(UseMultiblockInK)
{
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 16
if(testBlkGroupSize <= 16)
break;
iterations++;
};
blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
numBlockTileIteration_ = iterations;
}
else
{
blkGroupSize_ = 1;
numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
};
gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;
x_grid_desc_m_k_ =
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
scale_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
bias_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
mean_var_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
}
AccDataType epsilon_;
AccDataType averageFactor_;
bool updateMovingAverage_;
bool saveMeanInvVariance_;
std::array<index_t, Rank> xyLengths_;
std::array<index_t, Rank> xStrides_;
std::array<index_t, Rank> yStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
const XDataType* p_x_;
const ScaleDataType* p_scale_;
const BiasDataType* p_bias_;
const YElementwiseOp y_elementwise_op_;
YDataType* p_y_;
MeanVarDataType* resultSaveMean_;
MeanVarDataType* resultSaveInvVariance_;
MeanVarDataType* resultRunningMean_;
MeanVarDataType* resultRunningVariance_;
long_index_t invariant_length_;
long_index_t reduce_length_;
int blkGroupSize_;
int numBlockTileIteration_;
size_t gridSize_;
XYGridDesc_M_K x_grid_desc_m_k_;
XYGridDesc_M_K y_grid_desc_m_k_;
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;
void* workspace_mean_;
void* workspace_variance_;
void* workspace_count_;
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
size_t workspace_size = 0;
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// workspace for welford intermediate mean
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate variance
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate count
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
}
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// setup buffer used for intermediate welford mean
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford varirance
pArg_->workspace_variance_ =
reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;
index_t variance_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welfor count
pArg_->workspace_count_ =
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
};
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0;
if(UseMultiblockInK && arg.blkGroupSize_ > 1)
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);
const auto mean_var_count_grid_desc_m_g =
DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
arg.invariant_length_, arg.blkGroupSize_);
const auto mean_var_count_grid_desc_m_k =
DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
arg.invariant_length_, arg.blkGroupSize_);
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using GridwiseMultiblockWelfordFirstHalf_ =
GridwiseMultiblockWelfordFirstHalf<XDataType,
AccDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize>;
using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
GridwiseWelfordSecondHalfBatchNormForwardFinal<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
const auto kern_multiblock_welford_first_half =
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
XDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor>;
const auto kern_welford_second_half_batchnorm_forward_final =
kernel_welford_second_half_batchnorm_forward_final<
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M>;
avg_time +=
launch_and_time_kernel(stream_config,
kern_multiblock_welford_first_half,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_));
avg_time +=
launch_and_time_kernel(stream_config,
kern_welford_second_half_batchnorm_forward_final,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
mean_var_count_grid_desc_m_k,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
arg.blkGroupSize_,
arg.numBlockTileIteration_,
arg.epsilon_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_),
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_,
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_,
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
}
else
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.numBlockTileIteration_, arg.reduce_length_);
using GridwiseBatchNormForwardWithBlockwiseWelford_ =
GridwiseBatchNormForwardWithBlockwiseWelford<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
GridwiseBatchNormForwardWithBlockwiseWelford_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor>;
avg_time += launch_and_time_kernel(stream_config,
kern_batchnorm_fwd,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.epsilon_,
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_, // true or false
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_, // true or false
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
};
return (avg_time);
};
float Run(const BaseArgument* pArg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* pArg) override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
if constexpr(XSrcYDstVectorDim == 0)
{
if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
pArg_->yStrides_[NumInvariantDim - 1] != 1)
return false;
if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
return false;
}
else
{
if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
return false;
if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
return false;
};
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
return false;
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
return false;
if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
return false;
bool is_valid = true;
static_for<0, NumInvariantDim, 1>{}([&](auto I) {
if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
is_valid = false;
});
if(!is_valid)
return false;
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* p_scale,
const void* p_bias,
double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y,
void* resultSaveMean,
void* resultSaveInvVariance,
double averageFactor,
void* resultRunningMean,
void* resultRunningVariance) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
yStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const BiasDataType*>(p_bias),
y_elementwise_op,
epsilon,
static_cast<YDataType*>(p_y),
static_cast<MeanVarDataType*>(resultSaveMean),
static_cast<MeanVarDataType*>(resultSaveInvVariance),
averageFactor,
static_cast<MeanVarDataType*>(resultRunningMean),
static_cast<MeanVarDataType*>(resultRunningVariance));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
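The two-kernel path listed above splits Welford's algorithm: each workgroup writes a partial (mean, variance, count) triple for its K slice into the workspace, and the second kernel merges those partials per channel before applying scale, bias and epsilon. The merge is the standard parallel-variance combination (Chan et al.); a self-contained sketch of that update, not CK's exact code, is:

// Sketch of merging two Welford partial results, as done conceptually by the
// multiblock (two-kernel) batchnorm path above.
#include <cstdio>

struct WelfordPartial
{
    double mean;
    double var;  // running (biased) variance, i.e. M2 / count
    int    count;
};

WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
    WelfordPartial out{};
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;

    const double delta = b.mean - a.mean;
    const double na    = static_cast<double>(a.count);
    const double nb    = static_cast<double>(b.count);
    const double n     = static_cast<double>(out.count);

    out.mean = a.mean + delta * nb / n;
    // combine M2 = var * count of both halves, plus the between-halves correction term
    const double m2 = a.var * na + b.var * nb + delta * delta * na * nb / n;
    out.var = m2 / n;
    return out;
}

int main()
{
    // two partials such as two workgroups reducing disjoint K slices would produce
    const WelfordPartial p0{2.0, 1.0, 4};
    const WelfordPartial p1{4.0, 1.0, 4};
    const WelfordPartial m = welford_merge(p0, p1);
    std::printf("mean=%g var=%g count=%d\n", m.mean, m.var, m.count); // mean=3 var=2 count=8
    return 0;
}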
@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
ALayout,
BLayout,
CLayout,
-ADataType, // TODO: distinguish A/B datatype
+ADataType,
+BDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,

@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
-const auto kernel =
-kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, true>;
+const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
+ADataType,
+BDataType,
+CDataType,
+true>;

ave_time += launch_and_time_kernel(stream_config,
kernel,

@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
}
else
{
-const auto kernel =
-kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, false>;
+const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
+ADataType,
+BDataType,
+CDataType,
+false>;

ave_time += launch_and_time_kernel(stream_config,
kernel,

@@ -448,6 +455,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
return GridwiseGemm::CheckValidity(arg);
}
......
@@ -310,9 +310,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
+
+using ComputeDataType = ADataType;
+
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
+BDataType,
+ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,

@@ -355,14 +359,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
LoopSched>;

// desc for blockwise copy
-using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
+using AGridDesc_AK0_M_AK1 =
+remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+AGridDesc_M_K{}))>;
+using BGridDesc_BK0_N_BK1 =
+remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+BGridDesc_N_K{}))>;
+using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+DsGridDesc_M_N{}))>;
+using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+EGridDesc_M_N{}))>;

// block-to-e-tile map
using Block2ETileMap =

@@ -582,9 +590,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
-if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
-ck::get_device_name() == "gfx942"))
+if(!ck::is_xdl_supported())
{
return false;
}
......
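Several hunks in this commit replace hard-coded gfx device-name lists with a call to ck::is_xdl_supported(). The helper itself is not part of this diff view; based on the device list it replaces (gfx908/gfx90a/gfx940/gfx941/gfx942), it presumably amounts to something like the sketch below. The stubbed get_device_name() is only there to keep the sketch self-contained; in CK it is provided by ck/host_utility/device_prop.hpp.

// Hedged sketch only: what a helper like ck::is_xdl_supported() is presumed to check,
// inferred from the explicit device list removed in the hunks above.
#include <string>

namespace ck {

// stub standing in for the real helper from ck/host_utility/device_prop.hpp
inline std::string get_device_name() { return "gfx90a"; }

inline bool is_xdl_supported()
{
    const std::string dev = get_device_name();
    return dev == "gfx908" || dev == "gfx90a" || dev == "gfx940" || dev == "gfx941" ||
           dev == "gfx942";
}

} // namespace ck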
@@ -532,11 +532,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float ave_time = 0;

const auto Run = [&](const auto& kernel) {
-hipGetErrorString(hipMemset(
+hipGetErrorString(hipMemsetAsync(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
-sizeof(CDataType)));
+sizeof(CDataType),
+stream_config.stream_id_));

ave_time =
launch_and_time_kernel(stream_config,

@@ -649,6 +650,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
// vector load A/B matrix from global memory
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
......
@@ -616,6 +616,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
......
@@ -810,6 +810,11 @@ struct
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
......
@@ -767,6 +767,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
......
@@ -741,6 +741,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
......
@@ -524,6 +524,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
......
@@ -524,6 +524,11 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
......
@@ -1320,6 +1320,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
......
@@ -683,6 +683,11 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
static bool IsSupportedArgument(const Argument& arg)
{
+if(!ck::is_xdl_supported())
+{
+return false;
+}
+
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
......
@@ -11,6 +11,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"

@@ -59,6 +60,7 @@ template <
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
+GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&

@@ -236,7 +238,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
-CThreadTransferDstScalarPerVector>;
+CThreadTransferDstScalarPerVector,
+GemmDlAlg>;

using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));

@@ -372,7 +375,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
-true>;
+true,
+GemmDlAlg>;

ave_time = launch_and_time_kernel(stream_config,
kernel,

@@ -398,7 +402,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
-false>;
+false,
+GemmDlAlg>;

ave_time = launch_and_time_kernel(stream_config,
kernel,

@@ -424,7 +429,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
-true>;
+true,
+GemmDlAlg>;

ave_time = launch_and_time_kernel(stream_config,
kernel,

@@ -450,7 +456,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
-false>;
+false,
+GemmDlAlg>;

ave_time = launch_and_time_kernel(stream_config,
kernel,

@@ -485,6 +492,16 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
+if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
+{
+if(ck::get_device_name() == "gfx1030")
+{
+return GridwiseGemm::CheckValidity(
+arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
+}
+return false;
+}
+
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")

@@ -492,10 +509,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
}
-else
-{
-return false;
-}
+return false;
}

// polymorphic

@@ -572,7 +586,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}

// polymorphic
-std::string GetTypeString() const override
+virtual std::string GetTypeString() const override
{
auto str = std::stringstream();
......
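The newly included gemm_dl_algorithm.hpp header is not part of this diff view. From the two usages visible in this commit (GemmDlAlgorithm::Default as the new template default above, and GemmDlAlgorithm::Dpp8 in the derived device below), it presumably declares little more than the following sketch; the real header may contain additional values or live in a slightly different namespace.

// Hedged sketch of ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp,
// inferred from the enumerators used in this commit.
#pragma once

namespace ck {
namespace tensor_operation {
namespace device {

enum struct GemmDlAlgorithm
{
    Default, // generic DL (non-XDL) GEMM path
    Dpp8     // variant using DPP8 cross-lane data sharing, gated to gfx1030 above
};

} // namespace device
} // namespace tensor_operation
} // namespace ck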
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGemmDlDpp8 : public DeviceGemmDl<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
GemmDlAlgorithm::Dpp8>
{
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmDlDpp8"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -364,11 +364,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;

// We have to separate mean var descriptor for gemm and layernorm bacause of different grid
// layout(different padding)
-using GemmMeanVarGridDesc_M_NBlock = decltype(
-MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
+using GemmMeanVarGridDesc_M_NBlock =
+decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1,
+1));

-using GemmCountGridDesc_M_NBlock = decltype(
-MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
+using GemmCountGridDesc_M_NBlock =
+decltype(MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1,
+1));

using LayernormMeanVarGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>,

@@ -855,9 +857,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
-if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
-ck::get_device_name() == "gfx942"))
+if(!ck::is_xdl_supported())
{
return false;
}
......
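Every IsSupportedArgument() touched in this merge replaces the inline gfx908/gfx90a/gfx94x device-name comparison with a single call to ck::is_xdl_supported(). A minimal sketch of what that helper is assumed to reduce to, based only on the checks it replaces here (the real definition is expected to sit alongside get_device_name() in ck/host_utility/device_prop.hpp); the _sketch suffix marks it as illustrative, not the library's own function:

#include <string>

#include "ck/host_utility/device_prop.hpp"

namespace ck {
// Hypothetical reimplementation for illustration: true only on the XDL-capable targets
// that the removed per-operation checks enumerated explicitly.
inline bool is_xdl_supported_sketch()
{
    const std::string device = ck::get_device_name();
    return device == "gfx908" || device == "gfx90a" || device == "gfx940" ||
           device == "gfx941" || device == "gfx942";
}
} // namespace ck

Centralizing the guard means a newly supported XDL target only has to be added in one place instead of in every device operation's IsSupportedArgument().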
...@@ -337,10 +337,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -337,10 +337,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
RThreadTransferDstScalarPerVector_MPerBlock, RThreadTransferDstScalarPerVector_MPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
...@@ -555,9 +557,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -555,9 +557,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename ABDataType, typename ADataType,
typename BDataType,
typename DsPointer, typename DsPointer,
typename EDataType, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -36,8 +37,8 @@ __global__ void ...@@ -36,8 +37,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid, kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
...@@ -242,9 +243,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -242,9 +243,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1)); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = EDataType;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
DsDataType, DsDataType,
...@@ -288,14 +293,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -288,14 +293,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
PipelineVer>; PipelineVer>;
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map // block-to-e-tile map
using Block2ETileMap = using Block2ETileMap =
...@@ -438,6 +447,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -438,6 +447,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle< const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
BDataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemm::DsGridPointer,
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
...@@ -491,9 +501,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -491,9 +501,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
...@@ -645,6 +645,11 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio ...@@ -645,6 +645,11 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -65,7 +65,8 @@ template <typename ALayout, ...@@ -65,7 +65,8 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = CDataType>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ALayout, ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, // TODO: distinguish A/B datatype ADataType,
BDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
...@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
PipelineVer>; PipelineVer,
ComputeType>;
using Argument = typename GridwiseGemm::Argument; using Argument = typename GridwiseGemm::Argument;
...@@ -188,9 +191,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -188,9 +191,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......