"docs/vscode:/vscode.git/clone" did not exist on "3fff964d09b067637355524d99dd4e0365c0ef10"
Commit 61f4a7ee authored by Anthony Chang

implement scaling

parent aa0ee8e2
@@ -212,11 +212,6 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
#endif
// P = Softmax(S)
-// >>> scipy.special.softmax(numpy.eye(4), 1)
-// array([[0.47536689, 0.1748777 , 0.1748777 , 0.1748777 ],
-//        [0.1748777 , 0.47536689, 0.1748777 , 0.1748777 ],
-//        [0.1748777 , 0.1748777 , 0.47536689, 0.1748777 ],
-//        [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
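The deleted comment illustrated what unscaled softmax produces. For the record, a standalone sketch (not part of the example) that reproduces the same figures for one row of numpy.eye(4):

#include <cmath>
#include <cstdio>

int main()
{
    // first row of the 4x4 identity; softmax taken over the row
    double row[4] = {1.0, 0.0, 0.0, 0.0};
    double sum = 0.0;
    for(double x : row)
        sum += std::exp(x);
    for(double x : row)
        std::printf("%.8f ", std::exp(x) / sum); // 0.47536689 0.17487770 ...
    std::printf("\n");
    return 0;
}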
@@ -249,8 +244,7 @@ int run(int argc, char* argv[])
ck::index_t G0 = 3;
ck::index_t G1 = 2;
-// float alpha = 1.f / std::sqrt(K); // TODO: make scaling aware
-float alpha = 1.f;
+float alpha = 1.f / std::sqrt(K);
bool input_permute = false;
bool output_permute = false;
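alpha now defaults to the conventional scaled-dot-product-attention factor 1/sqrt(K) instead of 1.f. For orientation, a host-side sketch (hypothetical helper names, not the example's device path) of where alpha enters the forward pass, S = alpha * Q * K^T:

#include <cmath>
#include <vector>

// Q is M x K and the key matrix is N x K, both row-major; returns the
// scaled logits S (M x N) that are fed to softmax.
std::vector<float> scaled_logits(const std::vector<float>& q,
                                 const std::vector<float>& kmat,
                                 int M, int N, int K)
{
    const float alpha = 1.f / std::sqrt(static_cast<float>(K));
    std::vector<float> s(static_cast<size_t>(M) * N, 0.f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int kk = 0; kk < K; ++kk)
                acc += q[m * K + kk] * kmat[n * K + kk];
            s[m * N + n] = alpha * acc; // scaled logit
        }
    return s;
}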
@@ -488,10 +482,6 @@ int run(int argc, char* argv[])
return 0;
}
-if(alpha != 1.0f)
-{
-    std::cout << "not yet implemented scaling" << std::endl; // TODO: make scaling aware
-}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
@@ -24,7 +24,7 @@ template <typename DataType,
typename FloatLSE,
typename AElementwiseOperation,
typename BElementwiseOperation,
-typename AccElementwiseOperation,
+typename SElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
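The renamed SElementwiseOperation is the functor applied to the S = Q * K^T logits. As the s_element_op(dst, src) calls later in this diff show, the op writes its result through the first argument. A minimal sketch of a scale op with that shape (illustrative; in device code it would also be marked __host__ __device__):

// Illustrative functor shape for SElementwiseOperation; the kernel can
// equally be instantiated with PassThrough when no scaling is wanted.
struct ScaleBy
{
    explicit ScaleBy(float scale) : scale_(scale) {}

    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = static_cast<Y>(scale_ * static_cast<float>(x)); // y = alpha * x
    }

    float scale_;
};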
@@ -816,13 +816,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
return to_multi_index(BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
}
-template <typename CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4>
+template <typename CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
+          typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
-tensor_operation::element_wise::PassThrough, // CElementwiseOperation
+ElementwiseOp, // CElementwiseOperation
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim
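ThreadwiseTensorSliceTransfer_v1r3 applies its element-wise operation to every element it moves, so defaulting ElementwiseOp to PassThrough keeps existing instantiations source-compatible while letting the dK path scale values as they are written out, with no extra pass over the data. Conceptually (a sketch, not the real transfer class):

// Conceptual sketch: apply an element-wise op while copying a register
// slice out, so the scaling is fused into the store.
template <typename ElementwiseOp, typename T, int N>
void copy_with_op(const T (&src)[N], T (&dst)[N], const ElementwiseOp& op)
{
    for(int i = 0; i < N; ++i)
        op(dst[i], src[i]); // op writes its result into dst[i]
}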
@@ -1083,7 +1084,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
-const AccElementwiseOperation& acc_element_op,
+const SElementwiseOperation& s_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
@@ -1446,11 +1447,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
make_multi_index(
I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
-auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
-kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
-kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
+decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
+decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
-tensor_operation::element_wise::PassThrough{});
+s_element_op);
//
// set up Y dot dY
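Because the new template parameter is defaulted, only the call sites that want scaling change; the dK writer above names decltype(s_element_op) explicitly and every other CBlockwiseCopy user is untouched. A reduced sketch of the pattern (illustrative types, not CK's):

struct PassThrough
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const { y = static_cast<Y>(x); }
};

template <typename Desc, typename Op = PassThrough>
struct ThreadCopy
{
    ThreadCopy(Desc desc, Op op = Op{}) : desc_(desc), op_(op) {}
    Desc desc_;
    Op op_;
};

// unchanged call sites: ThreadCopy<int> a{0};                        // Op = PassThrough
// the dK writer here:   ThreadCopy<int, ScaleBy> b{0, ScaleBy{0.5f}}; // scales on store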
@@ -1673,19 +1674,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
}
else
{
-acc_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
+s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
}
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
-[&](auto i) { acc_element_op(acc_thread_buf(i), s_slash_p_thread_buf[i]); });
+[&](auto i) { s_element_op(acc_thread_buf(i), s_slash_p_thread_buf[i]); });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
-// P_i: = softmax(S_i:)
+// P_i: = softmax(scalar * S_i:)
+// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
block_sync_lds(); // wait for gemm1 LDS read
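The backward kernel rebuilds P from the log-sum-exp statistics that the forward pass saved. Those statistics were taken over the scaled logits, so running s_element_op over S first makes RunWithPreCalcStats reproduce P = softmax(scalar * S) exactly. The per-row arithmetic, as a sketch (hypothetical helper, not the blockwise implementation):

#include <cmath>
#include <vector>

// p_i = exp(alpha * s_i - lse), where lse = log(sum_j exp(alpha * s_j))
// was stored by the forward pass.
std::vector<float> softmax_row_from_lse(std::vector<float> s, float alpha, float lse)
{
    for(float& x : s)
        x = std::exp(alpha * x - lse);
    return s;
}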
@@ -1783,7 +1785,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
});
// gemm dQ
-// dQ = dS * K
+// dQ = scalar * dS * K
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
@@ -1846,7 +1848,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
}
} // end gemm dQ
-// dK = dS^T * dQ
+// dK = scalar * dS^T * dQ
v_slash_k_grad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
// load KGrad Gemm B
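A quick chain-rule check of where the extra scalar in the dQ and dK gemms comes from: the forward pass computes S = scalar * Q * K^T, and everything downstream of the softmax depends on Q and K only through S, so

    dQ = scalar * dS * K
    dK = scalar * dS^T * Q

each gradient carries exactly one factor of the softmax scale, while dV = P^T * dY is unaffected because V never passes through the scaled product.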
@@ -2008,7 +2010,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-tensor_operation::element_wise::PassThrough,
+SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
@@ -2032,7 +2034,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
-tensor_operation::element_wise::PassThrough{}};
+s_element_op};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
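Here the scale rides on the VGPR-to-LDS staging copy; the LDS-to-global shuffle copy below is not changed by this commit, so each element is scaled exactly once on its way out. The two-stage epilogue, reduced to a sketch (illustrative, not the real transfer classes):

// Stage 1 applies s_element_op while staging registers into LDS;
// stage 2 moves LDS to global without further transformation.
template <typename T, int N, typename Op>
void shuffle_epilogue(const T (&vgpr)[N], T (&lds)[N], T* global_out, const Op& s_op)
{
    for(int i = 0; i < N; ++i)
        s_op(lds[i], vgpr[i]); // scaled here (SElementwiseOperation)
    for(int i = 0; i < N; ++i)
        global_out[i] = lds[i]; // plain copy
}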
......