Commit f90af872 authored by letaoqin's avatar letaoqin
Browse files

fix merge issue

parent 63e3f3c4
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define USING_MASK 0 #define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8. #define DIM 128 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -1332,7 +1332,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1332,7 +1332,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr auto d0_thread_desc_ = static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1< using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1361,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1361,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
true, // DstResetCoord true, // DstResetCoord
1>; 1>;
using D0ThreadWiseCopy = using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
...@@ -1372,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1372,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
2, // SrcScalarPerVector 2, // SrcScalarPerVector
2>; 2>;
using D0ThreadCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3< using D0ThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
typename TypeTransform<D0DataType>::Type, typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_), decltype(d0_thread_desc_),
...@@ -2054,7 +2054,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2054,7 +2054,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// gemm0 M loop // gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1; index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0 // D0
auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopy( auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
...@@ -2062,10 +2062,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2062,10 +2062,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadCopy( auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds( auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadwiseCopyVgprToLds(
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0), make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::Scale{rp_dropout});
......
...@@ -1388,7 +1388,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1388,7 +1388,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto d0_thread_desc_ = static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1< using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1417,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1417,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true, // DstResetCoord true, // DstResetCoord
1>; 1>;
using D0ThreadWiseCopy = using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
...@@ -1511,7 +1511,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1511,7 +1511,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
sizeof(GemmDataType) / sizeof(FloatGemmAcc); sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple( static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align); D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(),
max_lds_align);
static constexpr auto d0_block_space_offset = static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) / k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Operator::template TypeTransform<D0DataType>::Size; D0Operator::template TypeTransform<D0DataType>::Size;
...@@ -2126,7 +2127,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2126,7 +2127,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1; index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0 // D0
auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopy( auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
...@@ -2134,7 +2135,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2134,7 +2135,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadCopy( auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds( auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment