Commit 4939ee59 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 457308e3 87f2bbcf
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using InDataType = int8_t; using InDataType = int8_t;
using WeiDataType = int8_t; using WeiDataType = int8_t;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using InDataType = int8_t; using InDataType = int8_t;
using WeiDataType = int8_t; using WeiDataType = int8_t;
......
...@@ -13,10 +13,6 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -13,10 +13,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif() endif()
endforeach() endforeach()
set(target 0) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
set(target 1) endif()
endif()
endforeach()
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
......
...@@ -27,6 +27,21 @@ ...@@ -27,6 +27,21 @@
#define CK_WAVELET_MIN_BLOCK_PER_CU 2 #define CK_WAVELET_MIN_BLOCK_PER_CU 2
#endif #endif
// kernel attribute: amdgpu_waves_per_eu()
#ifdef CK_USE_WAVES_PER_EU
// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute
#ifndef CK_MIN_WAVES_PER_EU
#define CK_MIN_WAVES_PER_EU 0
#endif
#ifndef CK_MAX_WAVES_PER_EU
#define CK_MAX_WAVES_PER_EU 0
#endif
#else
#define CK_USE_WAVES_PER_EU 0
#endif
// buffer resource // buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
...@@ -148,6 +163,10 @@ ...@@ -148,6 +163,10 @@
#define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1 #define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1
// experimental feature: add instances using pipeline v2 // experimental feature: add instances using pipeline v2
#define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1 #define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1
// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy)
#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0
#endif
// hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
...@@ -173,6 +192,10 @@ ...@@ -173,6 +192,10 @@
// workaround: compiler issue on gfx908 // workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1 #define CK_WORKAROUND_SWDEV_388832 1
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
#define CK_WORKAROUND_SWDEV_3318619 0
// flag to enable (1) or disable (0) the debugging output in some kernels // flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0 #define DEBUG_LOG 0
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#pragma once #pragma once
#include "ck/tensor_description/cluster_descriptor.hpp" #include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp" #include "ck/utility/get_shift.hpp"
namespace ck { namespace ck {
...@@ -35,10 +35,11 @@ struct BlockwiseWelford ...@@ -35,10 +35,11 @@ struct BlockwiseWelford
static constexpr auto thread_cluster_desc = static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
template <typename CountDataType>
__device__ static inline void __device__ static inline void
Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
{ {
int count = count_a + count_b; CountDataType count = count_a + count_b;
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count; T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
T delta = mean_b - mean_a; T delta = mean_b - mean_a;
mean_a += delta * count_b_over_count; mean_a += delta * count_b_over_count;
...@@ -46,11 +47,12 @@ struct BlockwiseWelford ...@@ -46,11 +47,12 @@ struct BlockwiseWelford
count_a = count; count_a = count;
} }
__device__ static void Run(T& mean_value, T& var_value, int& count) template <typename CountDataType>
__device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
{ {
__shared__ T mean_block_buf[BlockSize]; __shared__ T mean_block_buf[BlockSize];
__shared__ T var_block_buf[BlockSize]; __shared__ T var_block_buf[BlockSize];
__shared__ int count_block_buf[BlockSize]; __shared__ CountDataType count_block_buf[BlockSize];
constexpr auto cluster_len_shift = get_shift<BufferLength_K>(); constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
...@@ -78,11 +80,11 @@ struct BlockwiseWelford ...@@ -78,11 +80,11 @@ struct BlockwiseWelford
T mean1 = mean_block_buf[offset1]; T mean1 = mean_block_buf[offset1];
T var1 = var_block_buf[offset1]; T var1 = var_block_buf[offset1];
int count1 = count_block_buf[offset1]; CountDataType count1 = count_block_buf[offset1];
T mean2 = mean_block_buf[offset2]; T mean2 = mean_block_buf[offset2];
T var2 = var_block_buf[offset2]; T var2 = var_block_buf[offset2];
int count2 = count_block_buf[offset2]; CountDataType count2 = count_block_buf[offset2];
Merge(mean1, var1, count1, mean2, var2, count2); Merge(mean1, var1, count1, mean2, var2, count2);
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#pragma once #pragma once
#include "ck/tensor_description/cluster_descriptor.hpp" #include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp" #include "ck/utility/get_shift.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp" #include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck { namespace ck {
......
...@@ -786,12 +786,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -786,12 +786,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 && if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0) arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
{ {
std::cout << "first" << std::endl;
return false; return false;
} }
if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1) if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
{ {
std::cout << "second" << std::endl;
return false; return false;
} }
} }
......
...@@ -10,12 +10,14 @@ ...@@ -10,12 +10,14 @@
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp" #include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp" #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize) static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
{ {
const auto grid_desc_m_g = const auto grid_desc_m_g = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize)); make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
const auto mPad = const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
...@@ -133,8 +135,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -133,8 +135,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize) static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
{ {
const auto reduceLength = blkGroupSize; const auto reduceLength = blkGroupSize;
const auto grid_desc_m_k = const auto grid_desc_m_k = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength)); make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
const auto mPad = const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
...@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations); (K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128 // we want the blkGroupSize be not more than 16
if(testBlkGroupSize <= 128) if(testBlkGroupSize <= 16)
break; break;
iterations++; iterations++;
...@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
void* workspace_mean_; void* workspace_mean_;
void* workspace_variance_; void* workspace_variance_;
void* workspace_count_; void* workspace_count_;
void* control_;
}; };
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
...@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// workspace for welford intermediate count // workspace for welford intermediate count
workspace_size += workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64; pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
// workspace for barrier objects, each barrier object consists of two integers
// TODO: allocate barrier object memory globally to reuse it by other operators
workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
sizeof(int) * 2;
} }
return (workspace_size); return (workspace_size);
...@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{ {
// setup buffer used for intermediate welford mean // setup buffer used for intermediate welford mean
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_); pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
...@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// setup buffer used for intermediate welfor count // setup buffer used for intermediate welfor count
pArg_->workspace_count_ = pArg_->workspace_count_ =
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz; reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
index_t count_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
count_space_sz = math::integer_least_multiple(count_space_sz, 64);
pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
M_BlockTileSize * sizeof(int) * 2;
hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
}; };
}; };
...@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using GridwiseMultiblockBatchNormForward_ =
GridwiseMultiblockBatchNormForward<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
using GridwiseMultiblockWelfordFirstHalf_ = using GridwiseMultiblockWelfordFirstHalf_ =
GridwiseMultiblockWelfordFirstHalf<XDataType, GridwiseMultiblockWelfordFirstHalf<XDataType,
AccDataType, AccDataType,
...@@ -441,9 +487,67 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -441,9 +487,67 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
BiasSrcVectorSize, BiasSrcVectorSize,
MeanVarSrcDstVectorSize>; MeanVarSrcDstVectorSize>;
index_t numMeanVarCountBlockTileIteration = // It is found that:
(arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize; // 1) gfx1030 does not support the GLC enabled vector load/store, so using the
// two-kernel method for gfx1030
// 2) Profiler on gfx908 could hang even though it works when running examples
// 3) Single-kernel method works on gfx1100, but the performance it not better
// than two-kernel method (due to more warps participating the barrier)
if(ck::get_device_name() == "gfx90a")
{
const auto kern_multiblock_batchnorm_fwd_ =
kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor>;
avg_time += launch_and_time_kernel(
stream_config,
kern_multiblock_batchnorm_fwd_,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
// workspace by multiple workgroups
mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
// workspace by each workgroup
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.epsilon_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_),
static_cast<int*>(arg.control_),
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_, // true or false
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_, // true or false
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
}
else
{
const auto kern_multiblock_welford_first_half = const auto kern_multiblock_welford_first_half =
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_, kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
XDataType, XDataType,
...@@ -467,8 +571,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -467,8 +571,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
ScaleBiasMeanVarGridDesc_M, ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M>; ScaleBiasMeanVarGridDesc_M>;
avg_time += avg_time += launch_and_time_kernel(
launch_and_time_kernel(stream_config, stream_config,
kern_multiblock_welford_first_half, kern_multiblock_welford_first_half,
dim3(arg.gridSize_), dim3(arg.gridSize_),
dim3(BlockSize), dim3(BlockSize),
...@@ -482,8 +586,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -482,8 +586,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static_cast<MeanVarDataType*>(arg.workspace_variance_), static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_)); static_cast<int32_t*>(arg.workspace_count_));
avg_time += avg_time += launch_and_time_kernel(
launch_and_time_kernel(stream_config, stream_config,
kern_welford_second_half_batchnorm_forward_final, kern_welford_second_half_batchnorm_forward_final,
dim3(arg.gridSize_), dim3(arg.gridSize_),
dim3(BlockSize), dim3(BlockSize),
...@@ -496,7 +600,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -496,7 +600,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
arg.mean_var_grid_desc_m_, arg.mean_var_grid_desc_m_,
arg.blkGroupSize_, arg.blkGroupSize_,
arg.numBlockTileIteration_, arg.numBlockTileIteration_,
numMeanVarCountBlockTileIteration,
arg.epsilon_, arg.epsilon_,
static_cast<MeanVarDataType*>(arg.workspace_mean_), static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_), static_cast<MeanVarDataType*>(arg.workspace_variance_),
...@@ -513,6 +616,7 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -513,6 +616,7 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
arg.saveMeanInvVariance_, arg.saveMeanInvVariance_,
arg.resultSaveMean_, arg.resultSaveMean_,
arg.resultSaveInvVariance_); arg.resultSaveInvVariance_);
};
} }
else else
{ {
......
...@@ -76,7 +76,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -76,7 +76,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams. // TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1; static constexpr index_t NumGemmKPrefetchStage = 1;
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler(); static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
static constexpr PipelineVersion PipelineVer = PipelineVersion::v2; static constexpr PipelineVersion PipelineVer = PipelineVersion::v1;
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize, BlockSize,
......
...@@ -459,7 +459,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -459,7 +459,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_k_wos_lengths[0]}, num_group_{a_g_n_k_wos_lengths[0]},
num_gemm_{},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
...@@ -508,9 +507,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -508,9 +507,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW; const auto XTilde = ConvStrideW / GcdStrideDilationW;
// number of GEMM
num_gemm_ = YTilde * XTilde;
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{ {
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
...@@ -626,7 +622,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -626,7 +622,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
void Print() const void Print() const
{ {
for(index_t i = 0; i < num_gemm_; i++) for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
{ {
std::cout << "a_grid_desc_ak0_m_ak1_container_" std::cout << "a_grid_desc_ak0_m_ak1_container_"
<< a_grid_desc_ak0_m_ak1_container_[i] << std::endl; << a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
...@@ -654,7 +650,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -654,7 +650,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// tensor descriptor for problem definition // tensor descriptor for problem definition
index_t num_group_; index_t num_group_;
index_t num_gemm_;
std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_; std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_; std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_; std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
...@@ -708,7 +703,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -708,7 +703,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
float ave_time = 0; float ave_time = 0;
for(index_t i = 0; i < arg.num_gemm_; i++) for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i], arg.b_grid_desc_n_k_container_[i],
...@@ -807,7 +802,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -807,7 +802,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
} }
// vector load for A matrix from global memory to LDS // vector load for A matrix from global memory to LDS
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>) if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{ {
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{ {
...@@ -862,7 +858,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -862,7 +858,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
} }
// vector store for E // vector store for E
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC>) if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
{ {
// vector store C matrix into global memory // vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -81,6 +82,36 @@ struct PassThrough ...@@ -81,6 +82,36 @@ struct PassThrough
y = x; y = x;
} }
#endif #endif
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<float, f8_t>(float& y, const f8_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<f8_t, float>(f8_t& y, const float& x) const
{
y = type_convert<f8_t>(x);
}
template <>
__host__ __device__ void operator()<half_t, f8_t>(half_t& y, const f8_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<f8_t, half_t>(f8_t& y, const half_t& x) const
{
y = type_convert<f8_t>(x);
}
}; };
struct UnaryConvert struct UnaryConvert
...@@ -109,6 +140,23 @@ struct ConvertBF16RTN ...@@ -109,6 +140,23 @@ struct ConvertBF16RTN
} }
}; };
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct Scale struct Scale
{ {
__host__ __device__ Scale(float scale) : scale_(scale) {} __host__ __device__ Scale(float scale) : scale_(scale) {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment