Commit 961f5e9e authored by rocking's avatar rocking
Browse files

Merge branch 'gemm_layernorm_welford' of...

Merge branch 'gemm_layernorm_welford' of https://github.com/ROCmSoftwarePlatform/composable_kernel into gemm_layernorm_welford
parents 8b2796c7 b1d07e4a
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp" #include "device_base.hpp"
...@@ -100,6 +101,23 @@ __global__ void ...@@ -100,6 +101,23 @@ __global__ void
#endif #endif
} }
template <typename GridwiseWelfordLayernorm,
typename XDataType,
typename YDataType,
typename MeanDataType,
typename VarDataType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_welford_layernorm2d_second_half(const XDataType* __restrict__ p_x_grid,
const MeanDataType* __restrict__ p_mean_grid,
const VarDataType* __restrict__ p_var_grid,
YDataType* __restrict__ p_y_grid)
{
GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid);
}
} // namespace ck } // namespace ck
namespace ck { namespace ck {
...@@ -309,6 +327,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -309,6 +327,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType, HDataType, FDataType, GDataType>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -459,7 +480,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -459,7 +480,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
// TODO float avg_time = 0;
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
...@@ -480,7 +502,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -480,7 +502,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle< const auto kernel_gemm_welford =
kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemm::DsGridPointer,
...@@ -499,8 +522,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -499,8 +522,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
typename GridwiseGemm::DefaultBlock2ETileMap, typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, const auto kernel_welford_layernorm =
kernel, kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType,
HDataType,
FDataType,
GDataType>;
avg_time +=
launch_and_time_kernel(stream_config,
kernel_gemm_welford,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -520,6 +551,18 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -520,6 +551,18 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.f_grid_desc_mblock_mperblock_nblock_, arg.f_grid_desc_mblock_mperblock_nblock_,
arg.g_grid_desc_mblock_mperblock_nblock_, arg.g_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_f_grid_,
arg.p_g_grid_,
arg.p_h_grid_);
return avg_time;
}; };
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
namespace ck {
template <typename XDataType, typename YDataType, typename MeanDataType, typename VarDataType>
struct GridwiseWelfordSecondHalfLayernorm2d
{
__device__ static void Run(const XDataType* __restrict__ p_x_grid,
const MeanDataType* __restrict__ p_mean_grid,
const VarDataType* __restrict__ p_var_grid,
YDataType* __restrict__ p_y_grid)
{
ignore = p_x_grid;
ignore = p_mean_grid;
ignore = p_var_grid;
ignore = p_y_grid;
} // run
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment