Commit b86b318b authored by Anthony Chang

clean up; add comment

parent 54d032b0
@@ -2,7 +2,6 @@
 #include <numeric>
 #include <initializer_list>
 #include <cstdlib>
-#include <stdlib.h>
 #include <half.hpp>
 #include "check_err.hpp"
 #include "config.hpp"
@@ -17,6 +16,13 @@
 #include "reference_gemm_layernorm.hpp"
 #include "gemm_specialization.hpp"
+// This example demonstrates a single fused kernel that runs GEMM and layernorm together
+//
+// The GEMM + Layernorm implementation is a specialized kernel that fuses both layers,
+// provided the GEMM N extent (of MNK) is spanned by a single workgroup. For example, a
+// kernel configured with NPerBlock = 128 can handle any GEMM size with N <= 128
+//
+// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
...
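As a quick reference for the formula in the new comment block, here is a minimal host-side sketch of the pre-layernorm part, acc = acc_element_op(A * B + broadcast(bias)) + add. The function name, the flat row-major std::vector layout, and the identity element op are illustrative assumptions, not the example's actual API; the layernorm step is sketched separately after the reference-op hunk below.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical sketch: acc = acc_element_op(A * B + broadcast(bias)) + add,
// with the element op taken as identity for brevity. Row-major flat layout.
std::vector<float> gemm_bias_add_reference(const std::vector<float>& a,    // MxK
                                           const std::vector<float>& b,    // KxN
                                           const std::vector<float>& bias, // N, broadcast over M
                                           const std::vector<float>& add,  // MxN
                                           std::size_t M, std::size_t N, std::size_t K)
{
    std::vector<float> acc(M * N);
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            // plain GEMM inner product over the K dimension
            float sum = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                sum += a[m * K + k] * b[k * N + n];
            // bias is broadcast along M; add is a full MxN tensor
            acc[m * N + n] = (sum + bias[n]) + add[m * N + n];
        }
    }
    return acc;
}
```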
@@ -22,6 +22,8 @@ namespace device {
 // Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because
 // the non-c-shuffle version currently has compiler issues with register spill, which in turn
 // causes validation failures.
+//
+// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
 template <typename ALayout,
           typename BLayout,
           typename CLayout,
...
@@ -15,7 +15,7 @@
 namespace ck {
-// D = Layernorm(A * B + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
+// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
...
@@ -9,6 +9,7 @@ namespace ck {
 namespace tensor_operation {
 namespace host {
+// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
@@ -28,7 +29,6 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
           BElementwiseOperation,
           element_wise::PassThrough>;
-// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
 template <typename InDataType, typename OutDataType, typename ComputeDataType>
 static void RunLayernorm(Tensor<OutDataType>& result,
                          const Tensor<ComputeDataType>& acc, // MxN
...
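To complete the picture, here is a minimal sketch of the row-wise layernorm step that the reference RunLayernorm presumably performs per the formula comment: mean and variance over the N dimension, normalization, then the gamma/beta affine transform. The function name, the flat array layout, and the epsilon default are illustrative assumptions, not the Tensor-based reference API.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the reference layernorm step:
// result = Layernorm(acc) * broadcast(gamma) + broadcast(beta), row-wise over N.
void run_layernorm_reference(std::vector<float>& result,      // MxN output
                             const std::vector<float>& acc,   // MxN input
                             const std::vector<float>& gamma, // N, broadcast over M
                             const std::vector<float>& beta,  // N, broadcast over M
                             std::size_t M, std::size_t N,
                             float epsilon = 1e-5f)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        // mean over the N (row) dimension
        float mean = 0.f;
        for(std::size_t n = 0; n < N; ++n)
            mean += acc[m * N + n];
        mean /= static_cast<float>(N);

        // variance over the N dimension
        float var = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            const float diff = acc[m * N + n] - mean;
            var += diff * diff;
        }
        var /= static_cast<float>(N);

        // normalize, then scale and shift with the broadcast gamma/beta
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for(std::size_t n = 0; n < N; ++n)
            result[m * N + n] = (acc[m * N + n] - mean) * inv_std * gamma[n] + beta[n];
    }
}
```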