Commit a3e487ca authored by Anthony Chang's avatar Anthony Chang
Browse files

add description in example code

parent f1b2e521
......@@ -5,11 +5,19 @@ Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is define
Y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
Input:
Computation graph:
Q, K, V, Y, dY, and per-row softmax stats computed beforehand during forward prop
K^T V
| |
| |
Q --- * ----- Softmax ----- * --> Y
S P
Outputs:
Kernel inputs:
Q, K, V, Y, dY, per-row softmax stats (LSE)
Kernel outputs:
dQ, dK, dV
......@@ -37,6 +45,7 @@ Outputs:
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment