"vscode:/vscode.git/clone" did not exist on "525af938f6d19a21f931cf9f02dfa9005fa022f9"
Commit 9e49c2bf authored by letaoqin's avatar letaoqin
Browse files

change bias pos

parent de6af588
......@@ -69,7 +69,7 @@ int run(int argc, char* argv[])
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f;// / std::sqrt(K);
float alpha = 1.f / std::sqrt(K);
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
......
......@@ -1152,35 +1152,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
acc_thread_buf,
num_k_block_main_loop);
// add bias
if(p_d_grid)
{
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr,
DDataType,
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
d_thread_buf;
d_threadwise_copy_globla_vgpr.Run(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d_grid_buf,
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf);
static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
// acc add bias
acc_thread_buf(i) += d_thread_buf[i];
});
d_threadwise_copy_globla_vgpr.MoveSrcSliceWindow(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
......@@ -1241,6 +1212,35 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias
if(p_d_grid)
{
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr,
DDataType,
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
d_thread_buf;
d_threadwise_copy_globla_vgpr.Run(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d_grid_buf,
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf);
static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
// acc add bias
acc_thread_buf(i) += d_thread_buf[i];
});
d_threadwise_copy_globla_vgpr.MoveSrcSliceWindow(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment