Commit b1d07e4a authored by rocking's avatar rocking
Browse files

Add second kernel for gemm+layernorm

parent c3107fd5
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp" #include "device_base.hpp"
...@@ -100,6 +101,23 @@ __global__ void ...@@ -100,6 +101,23 @@ __global__ void
#endif #endif
} }
// Second-phase kernel of the split gemm+layernorm pipeline: forwards the grid
// pointers to GridwiseWelfordLayernorm::Run, which is expected to finish the
// Welford mean/variance combination and apply the 2D layernorm.
// NOTE(review): at the launch site this kernel receives the first-half kernel's
// E (data), F (mean partials), G (variance partials) and H (output) grids as
// x/mean/var/y respectively — confirm once the gridwise body is implemented.
template <typename GridwiseWelfordLayernorm,
typename XDataType,    // input element type (first-half GEMM output)
typename YDataType,    // normalized output element type
typename MeanDataType, // per-tile mean partials from the first-half kernel
typename VarDataType>  // per-tile variance partials from the first-half kernel
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_welford_layernorm2d_second_half(const XDataType* __restrict__ p_x_grid,
const MeanDataType* __restrict__ p_mean_grid,
const VarDataType* __restrict__ p_var_grid,
YDataType* __restrict__ p_y_grid)
{
// Pure forwarding wrapper; all logic lives in the gridwise struct so the
// same device code can be reused by multiple device-level operators.
GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid);
}
} // namespace ck } // namespace ck
namespace ck { namespace ck {
...@@ -309,6 +327,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -309,6 +327,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType, HDataType, FDataType, GDataType>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -459,7 +480,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -459,7 +480,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
// TODO float avg_time = 0;
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
...@@ -480,46 +502,67 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -480,46 +502,67 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle< const auto kernel_gemm_welford =
GridwiseGemm, kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype GridwiseGemm,
typename GridwiseGemm::DsGridPointer, ADataType, // TODO: distinguish A/B datatype
EDataType, typename GridwiseGemm::DsGridPointer,
FDataType, EDataType,
GDataType, FDataType,
AElementwiseOperation, GDataType,
BElementwiseOperation, AElementwiseOperation,
CDEElementwiseOperation, BElementwiseOperation,
typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1, CDEElementwiseOperation,
typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1, typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::FGridDescriptor_MBlock_MPerBlock_NBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::GGridDescriptor_MBlock_MPerBlock_NBlock, typename GridwiseGemm::FGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, typename GridwiseGemm::GGridDescriptor_MBlock_MPerBlock_NBlock,
has_main_loop>; typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel, const auto kernel_welford_layernorm =
dim3(grid_size), kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
dim3(BlockSize), EDataType,
0, HDataType,
arg.p_a_grid_, FDataType,
arg.p_b_grid_, GDataType>;
arg.p_ds_grid_,
arg.p_e_grid_, avg_time +=
arg.p_f_grid_, launch_and_time_kernel(stream_config,
arg.p_g_grid_, kernel_gemm_welford,
arg.a_element_op_, dim3(grid_size),
arg.b_element_op_, dim3(BlockSize),
arg.cde_element_op_, 0,
arg.a_grid_desc_ak0_m_ak1_, arg.p_a_grid_,
arg.b_grid_desc_bk0_n_bk1_, arg.p_b_grid_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.p_ds_grid_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.p_e_grid_,
arg.f_grid_desc_mblock_mperblock_nblock_, arg.p_f_grid_,
arg.g_grid_desc_mblock_mperblock_nblock_, arg.p_g_grid_,
arg.block_2_etile_map_); arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.f_grid_desc_mblock_mperblock_nblock_,
arg.g_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_f_grid_,
arg.p_g_grid_,
arg.p_h_grid_);
return avg_time;
}; };
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
namespace ck {
template <typename XDataType, typename YDataType, typename MeanDataType, typename VarDataType>
struct GridwiseWelfordSecondHalfLayernorm2d
{
    // Placeholder for the second half of the split Welford + layernorm pass.
    // Intended contract (to be implemented): combine the per-tile mean/variance
    // partials in p_mean_grid / p_var_grid and write the normalized result of
    // p_x_grid into p_y_grid. The current body is intentionally empty.
    __device__ static void Run(const XDataType* __restrict__ p_x_grid,
                               const MeanDataType* __restrict__ p_mean_grid,
                               const VarDataType* __restrict__ p_var_grid,
                               YDataType* __restrict__ p_y_grid)
    {
        // Silence unused-parameter warnings until the real implementation lands.
        ignore = p_y_grid;
        ignore = p_var_grid;
        ignore = p_mean_grid;
        ignore = p_x_grid;
    } // Run
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment