"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "adbba96255eab11da3c18544645b72188b57f014"
Commit 61f4a7ee authored by Anthony Chang

implement scaling

parent aa0ee8e2
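
For context, the scaling implemented here is the standard attention logit scaling: the S = Q * K^T accumulator is multiplied by alpha = 1 / sqrt(K), where K is the head (GEMM-K) dimension, before the softmax, and the same factor has to reach the backward GEMMs for dQ and dK. Below is a minimal host-side sketch of the forward math only, in plain C++ with illustrative names rather than the kernel's tensor machinery:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// One row of P = softmax(alpha * S), with alpha = 1 / sqrt(head_dim).
// "s_row" stands for one row of the Q * K^T logits; the name is illustrative.
std::vector<float> scaled_softmax_row(const std::vector<float>& s_row, float alpha)
{
    float m = -std::numeric_limits<float>::infinity();
    for(float x : s_row)
        m = std::max(m, alpha * x); // max of the scaled logits, for numerical stability

    std::vector<float> p(s_row.size());
    float sum = 0.f;
    for(std::size_t i = 0; i < s_row.size(); ++i)
    {
        p[i] = std::exp(alpha * s_row[i] - m);
        sum += p[i];
    }
    for(float& x : p)
        x /= sum; // normalize so the row sums to 1
    return p;
}

With alpha = 1.f this reduces to the unscaled softmax that the example previously hard-coded.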
@@ -212,11 +212,6 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 #endif

     // P = Softmax(S)
-    // >>> scipy.special.softmax(numpy.eye(4), 1)
-    // array([[0.47536689, 0.1748777 , 0.1748777 , 0.1748777 ],
-    //        [0.1748777 , 0.47536689, 0.1748777 , 0.1748777 ],
-    //        [0.1748777 , 0.1748777 , 0.47536689, 0.1748777 ],
-    //        [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
     auto ref_softmax          = ReferenceSoftmaxInstance{};
     auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
     auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
@@ -249,8 +244,7 @@ int run(int argc, char* argv[])
     ck::index_t G0 = 3;
     ck::index_t G1 = 2;

-    // float alpha = 1.f / std::sqrt(K); // TODO: make scaling aware
-    float alpha = 1.f;
+    float alpha = 1.f / std::sqrt(K);

     bool input_permute  = false;
     bool output_permute = false;
@@ -488,10 +482,6 @@ int run(int argc, char* argv[])
         return 0;
     }

-    if(alpha != 1.0f)
-    {
-        std::cout << "not yet implemented scaling" << std::endl; // TODO: make scaling aware
-    }

     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...
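
The example above now derives alpha from the head dimension (1.f / std::sqrt(K)) instead of pinning it to 1 and bailing out, which is why the backward kernel below has to carry the same scalar: with S' = alpha * S and P = softmax(S'), the chain rule gives dL/dS = alpha * dL/dS', so the scalar also multiplies the dQ and dK GEMM results, exactly as the updated comments in the next file note. A hypothetical element-wise sketch of that bookkeeping, using an illustrative flat buffer rather than the kernel's thread buffers:

#include <vector>

// Fold the softmax scale into the dS buffer before the gradient GEMMs:
// dL/dS = alpha * dL/d(alpha * S), and the subsequent dQ/dK GEMMs consume
// the scaled dS. Buffer layout and names are illustrative only.
void scale_ds(std::vector<float>& ds_prime, float alpha)
{
    for(float& v : ds_prime)
        v *= alpha;
}

In the actual kernel the factor is applied through s_element_op rather than a separate pass, as the diff below shows.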
@@ -24,7 +24,7 @@ template <typename DataType,
           typename FloatLSE,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename AccElementwiseOperation,
+          typename SElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
@@ -816,13 +816,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         return to_multi_index(BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
     }

-    template <typename CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4>
+    template <typename CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
+              typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
     using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
         FloatGemmAcc,
         DataType,
         decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
         CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
-        tensor_operation::element_wise::PassThrough,                  // CElementwiseOperation
+        ElementwiseOp,                                                // CElementwiseOperation
         decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
         Sequence<0, 1, 2, 3, 4, 5, 6, 7>,                             // AccessOrder
         7,                                                            // VectorDim
@@ -1083,7 +1084,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         void* __restrict__ p_shared,
         const AElementwiseOperation& a_element_op,
         const BElementwiseOperation& b_element_op,
-        const AccElementwiseOperation& acc_element_op,
+        const SElementwiseOperation& s_element_op,
         const B1ElementwiseOperation& b1_element_op,
         const CElementwiseOperation& c_element_op,
         const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
@@ -1446,11 +1447,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             make_multi_index(
                 I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);

-        auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
-            kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
-            kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-            kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
-            tensor_operation::element_wise::PassThrough{});
+        auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
+            decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
+            decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                    kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
+                                    s_element_op);

        //
        // set up Y dot dY
@@ -1673,19 +1674,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                     }
                     else
                     {
-                        acc_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
+                        s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
                     }
                 });
             }
             else
             {
                 static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
-                    [&](auto i) { acc_element_op(acc_thread_buf(i), s_slash_p_thread_buf[i]); });
+                    [&](auto i) { s_element_op(acc_thread_buf(i), s_slash_p_thread_buf[i]); });
             }

             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm

-            // P_i: = softmax(S_i:)
+            // P_i: = softmax(scalar * S_i:)
+            // scaling is already performed in the preceding statements with s_element_op
             blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);

             block_sync_lds(); // wait for gemm1 LDS read
@@ -1783,7 +1785,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             });

             // gemm dQ
-            // dQ = dS * K
+            // dQ = scalar * dS * K
             {
                 // TODO: explore using dynamic buffer for a1 thread buffer
                 // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
@@ -1846,7 +1848,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 }
             } // end gemm dQ

-            // dK = dS^T * dQ
+            // dK = scalar * dS^T * dQ
             v_slash_k_grad_thread_buf.Clear();
             static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
                 // load KGrad Gemm B
@@ -2008,7 +2010,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 FloatCShuffle,
                 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
                 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-                tensor_operation::element_wise::PassThrough,
+                SElementwiseOperation,
                 Sequence<CShuffleMXdlPerWavePerShuffle,
                          CShuffleNXdlPerWavePerShuffle,
                          I1,
@@ -2032,7 +2034,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                                      n_thread_data_on_block_idx[I2],
                                      n_thread_data_on_block_idx[I3],
                                      n_thread_data_on_block_idx[I4]),
-                tensor_operation::element_wise::PassThrough{}};
+                s_element_op};

             // shuffle: blockwise copy C from LDS to global
             auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
...
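
The template changes above leave SElementwiseOperation free to carry the scale: the kernel only needs something callable as s_element_op(dst, src). Purely as an illustration of a functor shape that satisfies those call sites (the concrete operation is supplied by the host-side API, which is outside this commit; a device build would also need __host__ __device__ qualifiers):

// Sketch of a scale functor matching the s_element_op(dst, src) call pattern,
// writing dst = alpha * src. Names and details are illustrative only.
struct ScaleOp
{
    explicit ScaleOp(float alpha = 1.f) : alpha_(alpha) {}

    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = static_cast<Y>(alpha_ * static_cast<float>(x));
    }

    float alpha_;
};

Instantiating it with alpha = 1.f / std::sqrt(K) matches what the example file now requests, while PassThrough, the default for ElementwiseOp in CBlockwiseCopy, keeps the unscaled path available.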