Commit de6af588 authored by letaoqin's avatar letaoqin
Browse files

v2 add bias

parent 8e26a612
...@@ -23,7 +23,7 @@ int run(int argc, char* argv[]) ...@@ -23,7 +23,7 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
float p_drop = 0.1; float p_drop = 0;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -69,7 +69,7 @@ int run(int argc, char* argv[]) ...@@ -69,7 +69,7 @@ int run(int argc, char* argv[])
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0)); ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K); float alpha = 1.f;// / std::sqrt(K);
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides = std::vector<ck::index_t> a_gs_ms_ks_strides =
...@@ -137,7 +137,7 @@ int run(int argc, char* argv[]) ...@@ -137,7 +137,7 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1,1});
break; break;
case 2: case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
...@@ -297,6 +297,7 @@ int run(int argc, char* argv[]) ...@@ -297,6 +297,7 @@ int run(int argc, char* argv[])
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
Tensor<LSEDataType> lse_g_m_host_result( Tensor<LSEDataType> lse_g_m_host_result(
{BatchCount, M}); // scratch object after max + ln(sum) {BatchCount, M}); // scratch object after max + ln(sum)
Tensor<DDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N}); Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
...@@ -310,6 +311,10 @@ int run(int argc, char* argv[]) ...@@ -310,6 +311,10 @@ int run(int argc, char* argv[])
b1_gs_os_ns.ForEach([&](auto& self, auto idx) { b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
}); });
d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
z_gs_ms_ns.ForEach([&](auto& self, auto idx) { z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
...@@ -322,6 +327,10 @@ int run(int argc, char* argv[]) ...@@ -322,6 +327,10 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
//bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
self(idx) += d_g_m_n(idx);
});
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
......
...@@ -499,8 +499,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -499,8 +499,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const index_t raw_n_padded, const index_t raw_n_padded,
const index_t block_idx_m) const index_t block_idx_m)
{ {
ignore = p_d_grid;
ignore = d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -900,6 +898,52 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -900,6 +898,52 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// gemm1 K loop // gemm1 K loop
index_t gemm1_k_block_outer_index = 0; index_t gemm1_k_block_outer_index = 0;
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// bias (d matrix)
constexpr auto d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockId
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // RegisterNum
auto d_threadwise_copy_globla_vgpr =
ThreadwiseTensorSliceTransfer_v2<DDataType,
DDataType,
decltype(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
4,
1,
false>(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
// z is random number matrix for dropout verify // z is random number matrix for dropout verify
// //
// z vgpr copy to global // z vgpr copy to global
...@@ -992,9 +1036,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -992,9 +1036,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
static_cast<ushort*>(p_shared), static_cast<ushort*>(p_shared),
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize()); z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
ushort, ushort,
ushort, ushort,
...@@ -1111,6 +1152,35 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1111,6 +1152,35 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
acc_thread_buf, acc_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
// add bias
if(p_d_grid)
{
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr,
DDataType,
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
d_thread_buf;
d_threadwise_copy_globla_vgpr.Run(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d_grid_buf,
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf);
static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
// acc add bias
acc_thread_buf(i) += d_thread_buf[i];
});
d_threadwise_copy_globla_vgpr.MoveSrcSliceWindow(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
// 8d thread_desc in thread scope // 8d thread_desc in thread scope
constexpr auto c_thread_lengths = constexpr auto c_thread_lengths =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
...@@ -1145,7 +1215,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1145,7 +1215,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// do MNK padding or upper triangular masking // do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN) if constexpr(MaskOutUpperTriangle || PadN)
{ {
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local = auto m_local =
...@@ -1572,6 +1641,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1572,6 +1641,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
}); });
} }
} }
}; }; // namespace ck
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment