Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b1d07e4a
Commit
b1d07e4a
authored
Oct 05, 2022
by
rocking
Browse files
Add second kernel for gemm+layernorm
parent
c3107fd5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
121 additions
and
41 deletions
+121
-41
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+84
-41
include/ck/tensor_operation/gpu/grid/gridwise_welford_second_half_layernorm2d.hpp
...ion/gpu/grid/gridwise_welford_second_half_layernorm2d.hpp
+37
-0
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
b1d07e4a
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
...
@@ -100,6 +101,23 @@ __global__ void
...
@@ -100,6 +101,23 @@ __global__ void
#endif
#endif
}
}
template
<
typename
GridwiseWelfordLayernorm
,
typename
XDataType
,
typename
YDataType
,
typename
MeanDataType
,
typename
VarDataType
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_welford_layernorm2d_second_half
(
const
XDataType
*
__restrict__
p_x_grid
,
const
MeanDataType
*
__restrict__
p_mean_grid
,
const
VarDataType
*
__restrict__
p_var_grid
,
YDataType
*
__restrict__
p_y_grid
)
{
GridwiseWelfordLayernorm
::
Run
(
p_x_grid
,
p_mean_grid
,
p_var_grid
,
p_y_grid
);
}
}
// namespace ck
}
// namespace ck
namespace
ck
{
namespace
ck
{
...
@@ -309,6 +327,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -309,6 +327,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
using
GridwiseWelfordLayernorm
=
GridwiseWelfordSecondHalfLayernorm2d
<
EDataType
,
HDataType
,
FDataType
,
GDataType
>
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -459,7 +480,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -459,7 +480,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
// TODO
float
avg_time
=
0
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
ds_grid_desc_m_n_
,
...
@@ -480,46 +502,67 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -480,46 +502,67 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle
<
const
auto
kernel_gemm_welford
=
GridwiseGemm
,
kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle
<
ADataType
,
// TODO: distiguish A/B datatype
GridwiseGemm
,
typename
GridwiseGemm
::
DsGridPointer
,
ADataType
,
// TODO: distiguish A/B datatype
EDataType
,
typename
GridwiseGemm
::
DsGridPointer
,
FDataType
,
EDataType
,
GDataType
,
FDataType
,
AElementwiseOperation
,
GDataType
,
BElementwiseOperation
,
AElementwiseOperation
,
CDEElementwiseOperation
,
BElementwiseOperation
,
typename
GridwiseGemm
::
DefaultAGridDesc_AK0_M_AK1
,
CDEElementwiseOperation
,
typename
GridwiseGemm
::
DefaultBGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
DefaultAGridDesc_AK0_M_AK1
,
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
FGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
GGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemm
::
FGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
typename
GridwiseGemm
::
GGridDescriptor_MBlock_MPerBlock_NBlock
,
has_main_loop
>
;
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
const
auto
kernel_welford_layernorm
=
dim3
(
grid_size
),
kernel_welford_layernorm2d_second_half
<
GridwiseWelfordLayernorm
,
dim3
(
BlockSize
),
EDataType
,
0
,
HDataType
,
arg
.
p_a_grid_
,
FDataType
,
arg
.
p_b_grid_
,
GDataType
>
;
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
avg_time
+=
arg
.
p_f_grid_
,
launch_and_time_kernel
(
stream_config
,
arg
.
p_g_grid_
,
kernel_gemm_welford
,
arg
.
a_element_op_
,
dim3
(
grid_size
),
arg
.
b_element_op_
,
dim3
(
BlockSize
),
arg
.
cde_element_op_
,
0
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
p_a_grid_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
p_b_grid_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
p_ds_grid_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
p_e_grid_
,
arg
.
f_grid_desc_mblock_mperblock_nblock_
,
arg
.
p_f_grid_
,
arg
.
g_grid_desc_mblock_mperblock_nblock_
,
arg
.
p_g_grid_
,
arg
.
block_2_etile_map_
);
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
f_grid_desc_mblock_mperblock_nblock_
,
arg
.
g_grid_desc_mblock_mperblock_nblock_
,
arg
.
block_2_etile_map_
);
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_welford_layernorm
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_e_grid_
,
arg
.
p_f_grid_
,
arg
.
p_g_grid_
,
arg
.
p_h_grid_
);
return
avg_time
;
};
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
...
...
include/ck/tensor_operation/gpu/grid/gridwise_welford_second_half_layernorm2d.hpp
0 → 100644
View file @
b1d07e4a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
namespace
ck
{
template
<
typename
XDataType
,
typename
YDataType
,
typename
MeanDataType
,
typename
VarDataType
>
struct
GridwiseWelfordSecondHalfLayernorm2d
{
__device__
static
void
Run
(
const
XDataType
*
__restrict__
p_x_grid
,
const
MeanDataType
*
__restrict__
p_mean_grid
,
const
VarDataType
*
__restrict__
p_var_grid
,
YDataType
*
__restrict__
p_y_grid
)
{
ignore
=
p_x_grid
;
ignore
=
p_mean_grid
;
ignore
=
p_var_grid
;
ignore
=
p_y_grid
;
}
// run
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment