Commit 6c496076 authored by Anthony Chang

activation in correct order

parent 93235bb4
@@ -48,7 +48,8 @@ struct Relu
 using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp = ck::tensor_operation::element_wise::PassThrough;
-// Elementwise operation that operates on the output of matrix multiplication Acc = A * B
+// Elementwise operation that operates on the output of matrix multiplication
+// i.e., AccElementOp(A * B + bias)
 using AccElementOp = Relu;
 // Elementwise operation that operates on the output of layer normalization
 using CElementOp = Relu;
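To make the ordering in these comments concrete, the sketch below is a hypothetical, standalone scalar reference for one output row. It assumes Relu for both AccElementOp and CElementOp and a row-wise mean/variance layernorm; the helper name fused_row and the use of E[x^2] - E[x]^2 for the variance are illustrative assumptions, not code from this commit.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One M-row of: CElementOp(Layernorm(AccElementOp(A*B + bias) + add) * gamma + beta)
std::vector<float> fused_row(const std::vector<float>& acc,   // one row of A * B (length N)
                             const std::vector<float>& bias,  // 1xN
                             const std::vector<float>& add,   // one row of c0_add (length N)
                             const std::vector<float>& gamma, // 1xN
                             const std::vector<float>& beta,  // 1xN
                             float epsilon = 1e-5f)
{
    const std::size_t N = acc.size();
    std::vector<float> x(N), out(N);
    float mean = 0.f, mean_sq = 0.f;
    for(std::size_t j = 0; j < N; ++j)
    {
        // activation in correct order: AccElementOp(acc + bias), then the residual add
        x[j] = std::max(acc[j] + bias[j], 0.f) + add[j];
        mean += x[j] / N;
        mean_sq += x[j] * x[j] / N;
    }
    const float var = mean_sq - mean * mean; // E[x^2] - E[x]^2, an assumption for this sketch
    for(std::size_t j = 0; j < N; ++j)
    {
        const float y = (x[j] - mean) / std::sqrt(var + epsilon);
        out[j] = std::max(y * gamma[j] + beta[j], 0.f); // CElementOp on the layernorm output
    }
    return out;
}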
@@ -227,15 +228,16 @@ int main(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

-    // extra 5MN flops due to: bias + gamma + beta + norm_sub + norm_div,
+    // extra 6MN flops due to: bias + add + gamma + beta + norm_sub + norm_div,
     // excluding reduction steps
-    std::size_t flop = std::size_t(2) * M * N * K + std::size_t(5) * M * N;
+    std::size_t flop = std::size_t(2) * M * N * K + std::size_t(6) * M * N;

-    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
-                            sizeof(CDataType) * M * N + sizeof(CDataType) * 3 * N;
+    // extra MN and 3N due to c0_add (MxN), bias (1xN), gamma (1xN), beta (1xN)
+    std::size_t bytes = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                        sizeof(CDataType) * 2 * M * N + sizeof(C0DataType) * 3 * N;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
-    float gb_per_sec = num_btype / 1.E6 / ave_time;
+    float gb_per_sec = bytes / 1.E6 / ave_time;

     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;
...
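To see what the updated counters above evaluate to, here is a small standalone sketch with hypothetical sizes, data types, and kernel time (M = N = K = 4096, 2-byte stand-in element types, and ave_time = 1 ms are assumptions for illustration, not values from this example):

#include <cstddef>
#include <cstdint>
#include <iostream>

int main()
{
    // hypothetical problem size and kernel time, for illustration only
    const std::size_t M = 4096, N = 4096, K = 4096;
    const float ave_time = 1.0f; // ms

    using ADataType  = std::uint16_t; // 2-byte stand-ins for the real element types
    using BDataType  = std::uint16_t;
    using CDataType  = std::uint16_t;
    using C0DataType = std::uint16_t;

    // 2MNK for the GEMM, plus 6MN for bias + add + gamma + beta + norm_sub + norm_div
    const std::size_t flop = std::size_t(2) * M * N * K + std::size_t(6) * M * N;

    // A and B inputs, C output plus the c0_add input (2 * MN), and bias/gamma/beta (3 * N)
    const std::size_t bytes = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                              sizeof(CDataType) * 2 * M * N + sizeof(C0DataType) * 3 * N;

    std::cout << static_cast<float>(flop) / 1.E9 / ave_time << " TFlops, "
              << bytes / 1.E6 / ave_time << " GB/s" << std::endl;
}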
@@ -624,7 +624,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                            FloatCShuffle,
                            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                           AccElementwiseOperation,
+                           tensor_operation::element_wise::PassThrough,
                            Sequence<CShuffleMXdlPerWavePerShuffle,
                                     CShuffleNXdlPerWavePerShuffle,
                                     I1,
@@ -648,7 +648,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                            m_thread_data_on_block_idx[I3],
                            m_thread_data_on_block_idx[I4],
                            n_thread_data_on_block_idx[I2]),
-                acc_element_op};
+                tensor_operation::element_wise::PassThrough{}};

         // shuffle: blockwise copy C from LDS to global
         auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
@@ -883,41 +883,43 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         block_sync_lds();

-        // layernorm
-        {
-            // load from LDS and global, add bias
-            c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
-                                                 c_shuffle_block_buf,
-                                                 c_reduce_thread_desc_mperblock_nperblock,
-                                                 make_tuple(I0, I0),
-                                                 c_reduce_thread_buf);
-
-            c0_thread_copy_global_to_vgpr.Run(
-                c0_grid_desc_mblock_mperblock_nblock_nperblock,
-                c0_bias_grid_buf,
-                c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
-                make_tuple(I0, I0, I0, I0),
-                c0_thread_buf);
-
-            static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
-                [&](auto i) {
-                    c_reduce_thread_buf(i) +=
-                        static_cast<FloatReduceAcc>(c0_thread_buf(i)); // bias
-                });
-
-            c0_add_thread_copy_global_to_vgpr.Run(
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
-                c0_add_grid_buf,
-                c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
-                make_tuple(I0, I0, I0, I0),
-                c0_thread_buf);
-
-            static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
-                [&](auto i) {
-                    c_reduce_thread_buf(i) +=
-                        static_cast<FloatReduceAcc>(c0_thread_buf(i)); // add
-                });
+        // load from LDS and global, add bias
+        c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
+                                             c_shuffle_block_buf,
+                                             c_reduce_thread_desc_mperblock_nperblock,
+                                             make_tuple(I0, I0),
+                                             c_reduce_thread_buf);
+
+        c0_thread_copy_global_to_vgpr.Run(
+            c0_grid_desc_mblock_mperblock_nblock_nperblock,
+            c0_bias_grid_buf,
+            c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
+            make_tuple(I0, I0, I0, I0),
+            c0_thread_buf);
+
+        static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+            [&](auto i) {
+                FloatReduceAcc out;
+                acc_element_op(out, c_reduce_thread_buf(i) +
+                                        static_cast<FloatReduceAcc>(c0_thread_buf(i)));
+                c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias)
+            });
+
+        c0_add_thread_copy_global_to_vgpr.Run(
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+            c0_add_grid_buf,
+            c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
+            make_tuple(I0, I0, I0, I0),
+            c0_thread_buf);
+
+        static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+            [&](auto i) {
+                c_reduce_thread_buf(i) +=
+                    static_cast<FloatReduceAcc>(c0_thread_buf(i)); // add
+            });
+
+        // layernorm
+        {
             using ThreadwiseReduceD0 =
                 ThreadwiseReduction<FloatReduceAcc,
                                     decltype(c_reduce_thread_desc_mperblock_nperblock),
...
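Note how the kernel now calls acc_element_op(out, in) with an output reference rather than using a returned value. The functor below is a minimal standalone sketch of that calling convention in plain C++ (without the device qualifiers the library uses); it is illustrative only, not the library's Relu definition.

#include <iostream>

struct Relu
{
    // same (out, in) convention as acc_element_op(out, acc + bias) in the loop above
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = x > 0 ? x : 0;
    }
};

int main()
{
    float out = 0.f;
    Relu{}(out, 0.5f + 1.0f); // activation applied to acc + bias
    std::cout << out << std::endl; // prints 1.5
}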
@@ -26,20 +26,17 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
                                  AccDataType,
                                  AElementwiseOperation,
                                  BElementwiseOperation,
-                                 AccElementwiseOperation>;
+                                 element_wise::PassThrough>;

     // D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
     template <typename InDataType, typename OutDataType, typename ComputeDataType>
     static void RunLayernorm(Tensor<OutDataType>& result,
                              const Tensor<ComputeDataType>& acc,   // MxN
-                             const Tensor<InDataType>& bias,       // 1xN
-                             const Tensor<InDataType>& add,        // MxN
                              const Tensor<InDataType>& gamma,      // 1xN
                              const Tensor<InDataType>& beta,       // 1xN
                              const InDataType epsilon = 1e-5)
     {
-        assert(acc.mDesc.GetLengths()[1] == bias.mDesc.GetLengths()[0] &&
-               acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
+        assert(acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
               acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);

         size_t M = acc.mDesc.GetLengths()[0];
@@ -47,17 +44,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
         Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
         Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
-        Tensor<ComputeDataType> acc_layernorm(acc.mDesc);
-
-        // add bias
-        acc_layernorm.ForEach([&](auto& self, auto idx) {
-            self(idx[0], idx[1]) = acc(idx[0], idx[1]) + bias(idx[1]);
-        });
-
-        // add from other layer
-        acc_layernorm.ForEach([&](auto& self, auto idx) {
-            self(idx[0], idx[1]) += add(idx[0], idx[1]);
-        });
+        Tensor<ComputeDataType> acc_layernorm(acc);

         // reduce N dim
         for(size_t i = 0; i < M; i++)
@@ -152,13 +139,25 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
                                                       acc_m_n,
                                                       arg.a_element_op_,
                                                       arg.b_element_op_,
-                                                      arg.acc_element_op_);
+                                                      element_wise::PassThrough{});

             // gemm
             ref_invoker.Run(ref_argument);

+            // activation(acc + bias)
+            acc_m_n.ForEach([&](auto& self, auto idx) {
+                AccDataType out;
+                arg.acc_element_op_(out, acc_m_n(idx[0], idx[1]) + arg.c0_n_bias_(idx[1]));
+                self(idx[0], idx[1]) = out;
+            });
+
+            // add from other layers
+            acc_m_n.ForEach([&](auto& self, auto idx) {
+                self(idx[0], idx[1]) += arg.c0_m_n_add_(idx[0], idx[1]);
+            });
+
             // layernorm
-            RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_m_n_add_, arg.c0_n_gamma_, arg.c0_n_beta_);
+            RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_gamma_, arg.c0_n_beta_);

             // elementwise op
             arg.c_m_n_.ForEach([&](auto& self, auto idx) {
...
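The host reference now applies the activation to acc + bias before the residual add and the normalization, matching the kernel change above. A tiny standalone check with made-up numbers (and Relu standing in for the activation) shows why the two orderings are not interchangeable:

#include <algorithm>
#include <iostream>

int main()
{
    const float acc = -2.0f, bias = 1.0f;

    // order after this commit: activation applied to acc + bias
    const float fixed = std::max(acc + bias, 0.0f); // Relu(-1) = 0

    // previous order: activation applied to acc before the bias was added
    const float previous = std::max(acc, 0.0f) + bias; // Relu(-2) + 1 = 1

    std::cout << "fixed: " << fixed << ", previous: " << previous << std::endl;
}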