Commit 80da57fd authored by letaoqin's avatar letaoqin
Browse files

fix bug for void type

parent 90f8550b
...@@ -110,6 +110,13 @@ __global__ void ...@@ -110,6 +110,13 @@ __global__ void
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
// const index_t global_thread_id = get_thread_global_1d_id(); // const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, 0, offset); ck::philox ph(seed, 0, offset);
...@@ -122,7 +129,7 @@ __global__ void ...@@ -122,7 +129,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_d0_grid == nullptr ? nullptr : p_d0_grid + d0_batch_offset, tmp_p_d0_grid,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset, p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
...@@ -155,7 +162,7 @@ __global__ void ...@@ -155,7 +162,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_d0_grid == nullptr ? nullptr : p_d0_grid + d0_batch_offset, tmp_p_d0_grid,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset, p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
...@@ -618,11 +625,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -618,11 +625,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
d0_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides)},
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n_)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
...@@ -633,8 +635,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -633,8 +635,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{ b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
d0_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K( b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
...@@ -685,10 +685,23 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -685,10 +685,23 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_grid_desc_mblock_mperblock_nblock_nperblock_ = c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
if constexpr(!is_same<D0DataType, void>::value)
{
d0_grid_desc_m_n_ = Transform::MakeCGridDescriptor_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n_);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
}
} }
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
is_dropout_ = p_dropout > 0.0; // is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout; p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0)); p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
...@@ -1010,15 +1023,19 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1010,15 +1023,19 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
return false; return false;
} }
if(arg.d0_n_length_stride_[1] == 1 && if constexpr(!is_same<D0DataType, void>::value)
arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
{ {
return false; if(arg.d0_n_length_stride_[1] == 1 &&
} arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
if(arg.d0_n_length_stride_[1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1) {
{ return false;
return false; }
if(arg.d0_n_length_stride_[1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1)
{
return false;
}
} }
// Note: we need raw lengths since threadwise copy can not handle vector load when part of // Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds // vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
......
...@@ -992,9 +992,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -992,9 +992,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
wave_m_n_id[I0], // NInputIndex wave_m_n_id[I0], // NInputIndex
0)); // register number 0)); // register number
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
DropoutNRepeat, // NRepeat DropoutNRepeat, // NRepeat
...@@ -1293,6 +1290,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1293,6 +1290,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// add bias // add bias
if constexpr(!std::is_void<D0DataType>::value) if constexpr(!std::is_void<D0DataType>::value)
{ {
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// get register // get register
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
D0DataType, D0DataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment