Commit 2f2f5490 authored by letaoqin's avatar letaoqin
Browse files

chech bias pointer is nullptr

parent e0f595de
...@@ -2336,50 +2336,55 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2336,50 +2336,55 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// add bias // add bias
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc(); if(p_d0_grid != nullptr)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize()); D0Operator::d0_thread_desc_.GetElementSpaceSize());
ignore = d0_thread_buf; ignore = d0_thread_buf;
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds // load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf); d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow( d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite( d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf); D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds(); block_sync_lds();
// read data form lds // read data form lds
d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, d0_thread_copy_lds_to_vgpr.Run(
make_tuple(I0, I0, I0, I0, I0), D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
d0_block_buf, make_tuple(I0, I0, I0, I0, I0),
D0Operator::d0_thread_desc_, d0_block_buf,
make_tuple(I0, I0, I0, I0, I0), D0Operator::d0_thread_desc_,
d0_thread_buf); make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) { // bias add
constexpr index_t c_offset = static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i)); constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]); s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
}); });
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow( d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
} }
// do MNK padding or upper triangular masking // do MNK padding or upper triangular masking
...@@ -2608,37 +2613,41 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2608,37 +2613,41 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// output bias grad // output bias grad
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( if(p_d0grad_grid != nullptr)
p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); {
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run( d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0), make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf, sgrad_thread_buf,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
d0grad_block_buf); d0grad_block_buf);
block_sync_lds(); block_sync_lds();
// write data from lds to global // write data from lds to global
d0_block_copy_lds_to_global.Run( d0_block_copy_lds_to_global.Run(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf, d0grad_block_buf,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf, d0grad_grid_buf,
I0); I0);
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
}); });
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
} }
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0], SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment