Commit a3e487ca authored by Anthony Chang's avatar Anthony Chang
Browse files

add description in example code

parent f1b2e521
...@@ -5,13 +5,21 @@ Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is define ...@@ -5,13 +5,21 @@ Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is define
Y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o Y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
Input: Computation graph:
Q, K, V, Y, dY, and per-row softmax stats computed beforehand during forward prop K^T V
| |
| |
Q --- * ----- Softmax ----- * --> Y
S P
Outputs: Kernel inputs:
dQ, dK, dV Q, K, V, Y, dY, per-row softmax stats (LSE)
Kernel outputs:
dQ, dK, dV
*/ */
...@@ -37,6 +45,7 @@ Outputs: ...@@ -37,6 +45,7 @@ Outputs:
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment