gaoqiong / composable_kernel

Commit b86b318b, authored Jun 15, 2022 by Anthony Chang

    clean up; add comment

Parent: 54d032b0
Showing 4 changed files with 11 additions and 3 deletions.
example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp                      +7 −1
include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp            +2 −0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp         +1 −1
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp   +1 −1
example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp

@@ -2,7 +2,6 @@
 #include <numeric>
 #include <initializer_list>
 #include <cstdlib>
-#include <stdlib.h>
 #include <half.hpp>
 #include "check_err.hpp"
 #include "config.hpp"
...
@@ -17,6 +16,13 @@
 #include "reference_gemm_layernorm.hpp"
 #include "gemm_specialization.hpp"
+// This example demonstrates a single kernel that runs the GEMM layer and layernorm in one fused kernel.
+//
+// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers
+// together, given the condition that the GEMM extent N of MNK is spanned by a single workgroup. For
+// example, a kernel configured with NPerBlock = 128 can operate on all GEMM sizes with N <= 128.
+//
+// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
...
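To make the added comment concrete, here is a minimal host-side sketch of the fused computation it describes. This is an illustration only, not the CK kernel's actual code: the row-major float buffers, the std::function form of acc_element_op, and the epsilon term are all assumptions.

// Minimal reference sketch (assumed layout/types; not the CK implementation) of
// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
#include <cmath>
#include <functional>
#include <vector>

void gemm_layernorm_ref(const std::vector<float>& A,     // M x K, row-major
                        const std::vector<float>& B,     // K x N, row-major
                        const std::vector<float>& bias,  // N, broadcast over rows
                        const std::vector<float>& add,   // M x N, elementwise addend
                        const std::vector<float>& gamma, // N, broadcast over rows
                        const std::vector<float>& beta,  // N, broadcast over rows
                        std::vector<float>& D,           // M x N
                        int M, int N, int K,
                        const std::function<float(float)>& acc_element_op,
                        float epsilon = 1e-5f)
{
    std::vector<float> row(N);
    for(int m = 0; m < M; ++m)
    {
        // GEMM + bias broadcast + elementwise op + residual add, for one output row
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * B[k * N + n];
            row[n] = acc_element_op(acc + bias[n]) + add[m * N + n];
        }
        // Layernorm reduces over the N extent of the row (mean / population variance)
        float mean = 0.f;
        for(float v : row)
            mean += v;
        mean /= N;
        float var = 0.f;
        for(float v : row)
            var += (v - mean) * (v - mean);
        var /= N;
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for(int n = 0; n < N; ++n)
            D[m * N + n] = (row[n] - mean) * inv_std * gamma[n] + beta[n];
    }
}

Because the normalization reduces over N within each row, the fused kernel needs a single workgroup to span the whole N extent, which is the NPerBlock = 128 constraint described in the comment above.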
include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp

@@ -22,6 +22,8 @@ namespace device {
 // Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
 // non-c-shuffle version currently has compiler issues with register spill, which further causes
 // validation failures.
+//
+// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
 template <typename ALayout,
           typename BLayout,
           typename CLayout,
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp

@@ -15,7 +15,7 @@
 namespace ck {
-// D = Layernorm(A * B + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
+// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp

@@ -9,6 +9,7 @@ namespace ck {
 namespace tensor_operation {
 namespace host {
+// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
...
@@ -28,7 +29,6 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
                                  BElementwiseOperation,
                                  element_wise::PassThrough>;
-// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
 template <typename InDataType, typename OutDataType, typename ComputeDataType>
 static void RunLayernorm(Tensor<OutDataType>& result,
                          const Tensor<ComputeDataType>& acc, // MxN
...
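As a quick numeric check of the reference semantics (values chosen for illustration, not taken from the repo): ignoring epsilon, for one accumulator row acc = {1, 2, 3, 4} with gamma = 1 and beta = 0, the mean is 2.5 and the population variance is 1.25, so the per-row layernorm output is roughly {-1.342, -0.447, 0.447, 1.342}.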