Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d62cd096
Commit
d62cd096
authored
Dec 20, 2022
by
rocking
Browse files
Add base class for gemm layernorm
parent
8572bc7c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
90 additions
and
9 deletions
+90
-9
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16_welford.cpp
...yernorm/gemm_bias_relu_add_layernorm_xdl_fp16_welford.cpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp
...operation/gpu/device/device_gemm_multiple_d_layernorm.hpp
+67
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
...ce/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+22
-8
No files found.
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16_welford.cpp
View file @
d62cd096
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/device_memory.hpp"
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp
0 → 100644
View file @
d62cd096
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : H[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// H = layernorm(E)
// Assume:
// D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension in layernorm(E)
/// @brief Abstract interface for a device operation that fuses a GEMM, an
///        element-wise epilogue over extra "D" tensors, and a layernorm:
///            C = a_op(A) * b_op(B)
///            E = cde_op(C, D0, D1, ...)
///            H = layernorm(E)
///
/// This struct holds no state and implements nothing itself: it only declares
/// the two factory methods a concrete operator must override. Of the template
/// parameters, only DsDataType (through NumDTensor) and the four element-wise
/// operation types appear in the declarations below; the layout and the other
/// data-type parameters are not referenced inside this struct.
///
/// @tparam DsDataType  Must provide a static Size() usable in a constant
///                     expression; its value becomes NumDTensor.
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename HLayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename HDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename HElementwiseOperation>
struct DeviceGemmMultipleDLayernorm : public BaseOperator
{
    /// Number of D tensors, as reported by DsDataType::Size(). It fixes the
    /// length of the p_ds and StrideDs arrays taken by MakeArgumentPointer().
    static constexpr index_t NumDTensor = DsDataType::Size();

    /// @brief Build the type-erased argument object for one problem instance.
    ///
    /// All tensors are passed as untyped pointers; the element types are
    /// presumably those named by the corresponding *DataType template
    /// parameters (NOTE(review): confirm against the concrete implementation).
    ///
    /// @param p_a             Input tensor A (read-only).
    /// @param p_b             Input tensor B (read-only).
    /// @param p_ds            One read-only pointer per D tensor (NumDTensor entries).
    /// @param p_gamma         Read-only gamma input (presumably the layernorm scale).
    /// @param p_beta          Read-only beta input (presumably the layernorm shift).
    /// @param p_h             Output tensor H (writable).
    /// @param MRaw            Problem size M, as given by the caller.
    /// @param NRaw            Problem size N, as given by the caller.
    /// @param KRaw            Problem size K, as given by the caller.
    /// @param StrideA         Stride of A.
    /// @param StrideB         Stride of B.
    /// @param StrideDs        One stride per D tensor (NumDTensor entries).
    /// @param StrideH         Stride of H.
    /// @param epsilon         Layernorm epsilon, always passed as double; an
    ///                        implementation may convert it to its own compute type.
    /// @param a_element_op    Element-wise operation applied to A.
    /// @param b_element_op    Element-wise operation applied to B.
    /// @param cde_element_op  Element-wise operation producing E from C and the Ds.
    /// @param h_element_op    Element-wise operation associated with H.
    /// @return Owning pointer to the argument, typed as BaseArgument.
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        const void* p_gamma,
                        const void* p_beta,
                        void* p_h,
                        index_t MRaw,
                        index_t NRaw,
                        index_t KRaw,
                        index_t StrideA,
                        index_t StrideB,
                        std::array<index_t, NumDTensor> StrideDs,
                        index_t StrideH,
                        double epsilon,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op,
                        HElementwiseOperation h_element_op) = 0;

    /// @brief Build the type-erased invoker for this operation.
    /// @return Owning pointer to the invoker, typed as BaseInvoker.
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
// struct DeviceGemmMultipleDLayernorm
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
d62cd096
...
@@ -10,13 +10,13 @@
...
@@ -10,13 +10,13 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -236,7 +236,21 @@ template <typename ALayout,
...
@@ -236,7 +236,21 @@ template <typename ALayout,
index_t
LayernormBetaSrcVectorSize
,
index_t
LayernormBetaSrcVectorSize
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
:
public
BaseOperator
struct
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
:
public
DeviceGemmMultipleDLayernorm
<
ALayout
,
BLayout
,
DsLayout
,
HLayout
,
ADataType
,
BDataType
,
DsDataType
,
GammaDataType
,
BetaDataType
,
HDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
HElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
;
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
;
using
ELayout
=
HLayout
;
using
ELayout
=
HLayout
;
...
@@ -464,7 +478,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -464,7 +478,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t
StrideB
,
index_t
StrideB
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideH
,
index_t
StrideH
,
AccDataTyp
e
epsilon
,
doubl
e
epsilon
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
,
CDEElementwiseOperation
cde_element_op
,
...
@@ -505,7 +519,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -505,7 +519,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
NRaw_
{
NRaw
},
NRaw_
{
NRaw
},
KRaw_
{
KRaw
},
KRaw_
{
KRaw
},
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
epsilon_
{
epsilon
}
epsilon_
{
static_cast
<
AccDataType
>
(
epsilon
)
}
{
{
gemm_mean_var_grid_desc_m_nblock_
=
gemm_mean_var_grid_desc_m_nblock_
=
DeviceOp
::
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
DeviceOp
::
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
...
@@ -931,7 +945,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -931,7 +945,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t
StrideB
,
index_t
StrideB
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideH
,
index_t
StrideH
,
AccDataTyp
e
epsilon
,
doubl
e
epsilon
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
,
CDEElementwiseOperation
cde_element_op
,
...
@@ -973,11 +987,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -973,11 +987,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t
StrideB
,
index_t
StrideB
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideH
,
index_t
StrideH
,
AccDataTyp
e
epsilon
,
doubl
e
epsilon
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
,
CDEElementwiseOperation
cde_element_op
,
HElementwiseOperation
h_element_op
)
HElementwiseOperation
h_element_op
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_b
,
...
@@ -1000,7 +1014,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -1000,7 +1014,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment