Commit 7fefc966 authored by rocking's avatar rocking
Browse files

Rename kernel. Prepare to add second half kernel

parent b1727e6b
......@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
......@@ -40,7 +40,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_welford_xdl_cshuffle(
kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
......@@ -255,7 +255,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDWelford_xdl_cshuffle<
using GridwiseGemm = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
......@@ -478,7 +478,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_welford_xdl_cshuffle<
const auto kernel = kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
......
......@@ -82,7 +82,7 @@ template <typename ABDataType,
index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
index_t FGTransferScalarPerVector,
LoopScheduler LoopSched>
struct GridwiseGemmMultipleDWelford_xdl_cshuffle
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment