Commit f158f4d4 authored by letaoqin's avatar letaoqin
Browse files

fix comments

parent 80da57fd
...@@ -87,9 +87,6 @@ template <index_t NumDimG, ...@@ -87,9 +87,6 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionForward : public BaseOperator struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
{ {
static constexpr index_t NumAcc0Bias = 1;
static constexpr index_t NumAcc1Bias = 0;
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, const void* p_a,
const void* p_b0, const void* p_b0,
......
...@@ -698,7 +698,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -698,7 +698,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]); const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]); const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_d0_grid = p_acc0_biases_vec.size() const auto p_d0_grid = p_acc0_biases_vec.size() > 0
? static_cast<const D0DataType*>(p_acc0_biases_vec[i]) ? static_cast<const D0DataType*>(p_acc0_biases_vec[i])
: nullptr; : nullptr;
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]); const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
......
...@@ -1288,7 +1288,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1288,7 +1288,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias // add bias
if constexpr(!std::is_void<D0DataType>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize()); p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment